diff --git a/dump_test.go b/dump_test.go index ea6c1e25..877702d9 100644 --- a/dump_test.go +++ b/dump_test.go @@ -1 +1,62 @@ package req + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestDumpText(t *testing.T) { + SetFlags(LstdFlags | Lcost) + reqBody := "request body" + respBody := "response body" + reqHeader := "Request-Header" + respHeader := "Response-Header" + handler := func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(respHeader, "req") + w.Write([]byte(respBody)) + } + ts := httptest.NewServer(http.HandlerFunc(handler)) + header := Header{ + reqHeader: "hello", + } + resp, err := Post(ts.URL, header, reqBody) + if err != nil { + t.Fatal(err) + } + dump := resp.dump() + for _, keyword := range []string{reqBody, respBody, reqHeader, respHeader, "cost"} { + if !strings.Contains(dump, keyword) { + t.Errorf("dump missing part, want: %s", keyword) + } + } +} + +func TestDumpUpload(t *testing.T) { + SetFlags(LreqBody) + file1 := ioutil.NopCloser(strings.NewReader("file1")) + uploads := []FileUpload{ + { + FileName: "1.txt", + FieldName: "media", + File: file1, + }, + } + ts := newDefaultTestServer() + r, err := Post(ts.URL, uploads, Param{"hello": "req"}) + if err != nil { + t.Fatal(err) + } + dump := r.dump() + contains := []string{ + `Content-Disposition: form-data; name="hello"`, + `Content-Disposition: form-data; name="media"; filename="1.txt"`, + } + for _, contain := range contains { + if !strings.Contains(dump, contain) { + t.Errorf("multipart dump should contains: %s", contain) + } + } +} diff --git a/req_test.go b/req_test.go index 1ddeeb3e..926f27d6 100644 --- a/req_test.go +++ b/req_test.go @@ -1,6 +1,7 @@ package req import ( + "bytes" "encoding/json" "encoding/xml" "io/ioutil" @@ -61,6 +62,26 @@ func TestFormParam(t *testing.T) { } } +func TestParamWithBody(t *testing.T) { + reqBody := "request body" + p := Param{ + "name": "roc", + "job": "programmer", + } + buf := bytes.NewBufferString(reqBody) + ts := newDefaultTestServer() + r, err := Post(ts.URL, p, buf) + if err != nil { + t.Fatal(err) + } + if r.Request().URL.Query().Get("name") != "roc" { + t.Error("param should in the url when set body manually") + } + if string(r.reqBody) != reqBody { + t.Error("request body not equal") + } +} + func TestParamBoth(t *testing.T) { urlParam := QueryParam{ "access_token": "123abc", @@ -93,6 +114,46 @@ func TestParamBoth(t *testing.T) { } +func TestBody(t *testing.T) { + body := "request body" + handler := func(w http.ResponseWriter, r *http.Request) { + bs, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Fatal(err) + } + if string(bs) != body { + t.Errorf("body = %s; want = %s", bs, body) + } + } + ts := httptest.NewServer(http.HandlerFunc(handler)) + + // string + _, err := Post(ts.URL, body) + if err != nil { + t.Fatal(err) + } + + // []byte + _, err = Post(ts.URL, []byte(body)) + if err != nil { + t.Fatal(err) + } + + // *bytes.Buffer + var buf bytes.Buffer + buf.WriteString(body) + _, err = Post(ts.URL, &buf) + if err != nil { + t.Fatal(err) + } + + // io.Reader + _, err = Post(ts.URL, strings.NewReader(body)) + if err != nil { + t.Fatal(err) + } +} + func TestBodyJSON(t *testing.T) { type content struct { Code int `json:"code"` @@ -195,6 +256,15 @@ func TestHeader(t *testing.T) { if err != nil { t.Fatal(err) } + + httpHeader := make(http.Header) + for key, value := range header { + httpHeader.Add(key, value) + } + _, err = Head(ts.URL, httpHeader) + if err != nil { + t.Fatal(err) + } } func TestUpload(t *testing.T) { @@ -235,4 +305,9 @@ func TestUpload(t *testing.T) { if err != nil { t.Fatal(err) } + ts = newDefaultTestServer() + _, err = Post(ts.URL, File("*.go")) + if err != nil { + t.Fatal(err) + } } diff --git a/resp.go b/resp.go index 71d4bc0c..ff7b9961 100644 --- a/resp.go +++ b/resp.go @@ -1,12 +1,10 @@ package req import ( - "bytes" "encoding/json" "encoding/xml" "fmt" "io" - "io/ioutil" "net/http" "os" "regexp" @@ -25,13 +23,6 @@ type Resp struct { cost time.Duration } -func (r *Resp) getReqBody() io.ReadCloser { - if r.reqBody == nil { - return nil - } - return ioutil.NopCloser(bytes.NewReader(r.reqBody)) -} - // Cost returns time spent by the request func (r *Resp) Cost() time.Duration { return r.cost diff --git a/resp_test.go b/resp_test.go index ea6c1e25..6e881a3b 100644 --- a/resp_test.go +++ b/resp_test.go @@ -1 +1,130 @@ package req + +import ( + "encoding/json" + "encoding/xml" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestToJSON(t *testing.T) { + type Result struct { + Code int `json:"code"` + Msg string `json:"msg"` + } + r1 := Result{ + Code: 1, + Msg: "ok", + } + handler := func(w http.ResponseWriter, r *http.Request) { + data, _ := json.Marshal(&r1) + w.Write(data) + } + ts := httptest.NewServer(http.HandlerFunc(handler)) + r, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + var r2 Result + err = r.ToJSON(&r2) + if err != nil { + t.Fatal(err) + } + if r1 != r2 { + t.Errorf("json response body = %+v; want = %+v", r2, r1) + } +} + +func TestToXML(t *testing.T) { + type Result struct { + XMLName xml.Name + Code int `xml:"code"` + Msg string `xml:"msg"` + } + r1 := Result{ + XMLName: xml.Name{Local: "result"}, + Code: 1, + Msg: "ok", + } + handler := func(w http.ResponseWriter, r *http.Request) { + data, _ := xml.Marshal(&r1) + w.Write(data) + } + ts := httptest.NewServer(http.HandlerFunc(handler)) + r, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + var r2 Result + err = r.ToXML(&r2) + if err != nil { + t.Fatal(err) + } + if r1 != r2 { + t.Errorf("xml response body = %+v; want = %+v", r2, r1) + } +} + +func TestFormat(t *testing.T) { + SetFlags(LstdFlags | Lcost) + reqHeader := "Request-Header" + respHeader := "Response-Header" + reqBody := "request body" + respBody1 := "response body 1" + respBody2 := "response body 2" + respBody := fmt.Sprintf("%s\n%s", respBody1, respBody2) + handler := func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(respHeader, "req") + w.Write([]byte(respBody)) + } + ts := httptest.NewServer(http.HandlerFunc(handler)) + + // %v + r, err := Post(ts.URL, reqBody, Header{reqHeader: "hello"}) + if err != nil { + t.Fatal(err) + } + str := fmt.Sprintf("%v", r) + for _, keyword := range []string{ts.URL, reqBody, respBody} { + if !strings.Contains(str, keyword) { + t.Errorf("format %%v output lack of part, want: %s", keyword) + } + } + + // %-v + str = fmt.Sprintf("%-v", r) + for _, keyword := range []string{ts.URL, respBody1 + " " + respBody2} { + if !strings.Contains(str, keyword) { + t.Errorf("format %%-v output lack of part, want: %s", keyword) + } + } + + // %+v + str = fmt.Sprintf("%+v", r) + for _, keyword := range []string{reqBody, respBody, reqHeader, respHeader} { + if !strings.Contains(str, keyword) { + t.Errorf("format %%+v output lack of part, want: %s", keyword) + } + } +} + +func TestBytesAndString(t *testing.T) { + respBody := "response body" + handler := func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(respBody)) + } + ts := httptest.NewServer(http.HandlerFunc(handler)) + r, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + if string(r.Bytes()) != respBody { + t.Errorf("response body = %s; want = %s", r.Bytes(), respBody) + } + if r.String() != respBody { + t.Errorf("response body = %s; want = %s", r.String(), respBody) + } +} diff --git a/setting_test.go b/setting_test.go new file mode 100644 index 00000000..e71a6d7d --- /dev/null +++ b/setting_test.go @@ -0,0 +1,62 @@ +package req + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func newDefaultTestServer() *httptest.Server { + handler := func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi")) + } + return httptest.NewServer(http.HandlerFunc(handler)) +} + +func TestSetClient(t *testing.T) { + + ts := newDefaultTestServer() + + client := &http.Client{} + SetClient(client) + _, err := Get(ts.URL) + if err != nil { + t.Errorf("error after set client: %v", err) + } + + SetClient(nil) + _, err = Get(ts.URL) + if err != nil { + t.Errorf("error after set client to nil: %v", err) + } + + client = Client() + if trans, ok := client.Transport.(*http.Transport); ok { + trans.MaxIdleConns = 1 + trans.DisableKeepAlives = true + _, err = Get(ts.URL) + if err != nil { + t.Errorf("error after change client's transport: %v", err) + } + } else { + t.Errorf("transport is not http.Transport: %+#v", client.Transport) + } +} + +func TestSetting(t *testing.T) { + defer func() { + if rc := recover(); rc != nil { + t.Errorf("panic happened while change setting: %v", rc) + } + }() + SetTimeout(2 * time.Second) + EnableCookie(false) + EnableCookie(true) + EnableInsecureTLS(true) + SetJSONIndent("", " ") + SetJSONEscapeHTML(false) + SetXMLIndent("", "\t") + SetProxyUrl("http://localhost:8080") + SetProxy(nil) +}