diff --git a/client.go b/client.go index 28433b76..0efc131f 100644 --- a/client.go +++ b/client.go @@ -3,6 +3,8 @@ package linodego import ( "bytes" "context" + "crypto/tls" + "crypto/x509" "encoding/json" "fmt" "io" @@ -14,13 +16,12 @@ import ( "path/filepath" "reflect" "regexp" + "runtime" "strconv" "strings" "sync" "text/template" "time" - - "github.com/go-resty/resty/v2" ) const ( @@ -63,30 +64,80 @@ Headers: {{.Headers}} Body: {{.Body}}`)) ) +type RequestLog struct { + Method string + URL string + Headers http.Header + Body string +} + +type ResponseLog struct { + Method string + URL string + Headers http.Header + Body string +} + var envDebug = false -// Client is a wrapper around the Resty client +// Client is a wrapper around the http client type Client struct { - resty *resty.Client - userAgent string - debug bool - retryConditionals []RetryConditional - + //nolint:unused + httpClient *http.Client + //nolint:unused + userAgent string + //nolint:unused + debug bool + + //nolint:unused pollInterval time.Duration - baseURL string - apiVersion string - apiProto string + //nolint:unused + baseURL string + //nolint:unused + apiVersion string + //nolint:unused + apiProto string + //nolint:unused + hostURL string + //nolint:unused + header http.Header + //nolint:unused selectedProfile string - loadedProfile string + //nolint:unused + loadedProfile string + //nolint:unused configProfiles map[string]ConfigProfile // Fields for caching endpoint responses - shouldCache bool + //nolint:unused + shouldCache bool + //nolint:unused cacheExpiration time.Duration - cachedEntries map[string]clientCacheEntry + //nolint:unused + cachedEntries map[string]clientCacheEntry + //nolint:unused cachedEntryLock *sync.RWMutex + //nolint:unused + logger Logger + //nolint:unused + requestLog func(*RequestLog) error + //nolint:unused + onBeforeRequest []func(*http.Request) error + //nolint:unused + onAfterResponse []func(*http.Response) error + + //nolint:unused + retryConditionals []RetryConditional + //nolint:unused + retryMaxWaitTime time.Duration + //nolint:unused + retryMinWaitTime time.Duration + //nolint:unused + retryAfter RetryAfter + //nolint:unused + retryCount int } type EnvDefaults struct { @@ -103,13 +154,11 @@ type clientCacheEntry struct { } type ( - Request = resty.Request - Response = resty.Response - Logger = resty.Logger + Request = http.Request + Response = http.Response ) func init() { - // Whether we will enable Resty debugging output if apiDebug, ok := os.LookupEnv("LINODE_DEBUG"); ok { if parsed, err := strconv.ParseBool(apiDebug); err == nil { envDebug = parsed @@ -123,7 +172,7 @@ func init() { // SetUserAgent sets a custom user-agent for HTTP requests func (c *Client) SetUserAgent(ua string) *Client { c.userAgent = ua - c.resty.SetHeader("User-Agent", c.userAgent) + c.SetHeader("User-Agent", c.userAgent) return c } @@ -136,7 +185,7 @@ type RequestParams struct { // Generic helper to execute HTTP requests using the net/http package // // nolint:unused, funlen, gocognit -func (c *httpClient) doRequest(ctx context.Context, method, url string, params RequestParams) error { +func (c *Client) doRequest(ctx context.Context, method, url string, params RequestParams, paginationMutator *func(*http.Request) error) error { var ( req *http.Request bodyBuffer *bytes.Buffer @@ -144,7 +193,7 @@ func (c *httpClient) doRequest(ctx context.Context, method, url string, params R err error ) - for range httpDefaultRetryCount { + for range c.retryCount { req, bodyBuffer, err = c.createRequest(ctx, method, url, params) if err != nil { return err @@ -154,6 +203,15 @@ func (c *httpClient) doRequest(ctx context.Context, method, url string, params R return err } + if paginationMutator != nil { + if err := (*paginationMutator)(req); err != nil { + if c.debug && c.logger != nil { + c.logger.Errorf("failed to mutate before request: %v", err) + } + return fmt.Errorf("failed to mutate before request: %w", err) + } + } + if c.debug && c.logger != nil { c.logRequest(req, method, url, bodyBuffer) } @@ -215,9 +273,10 @@ func (c *httpClient) doRequest(ctx context.Context, method, url string, params R } // nolint:unused -func (c *httpClient) shouldRetry(resp *http.Response, err error) bool { +func (c *Client) shouldRetry(resp *http.Response, err error) bool { for _, retryConditional := range c.retryConditionals { if retryConditional(resp, err) { + log.Printf("[INFO] Received error %v - Retrying", err) return true } } @@ -225,19 +284,19 @@ func (c *httpClient) shouldRetry(resp *http.Response, err error) bool { } // nolint:unused -func (c *httpClient) createRequest(ctx context.Context, method, url string, params RequestParams) (*http.Request, *bytes.Buffer, error) { +func (c *Client) createRequest(ctx context.Context, method, url string, params RequestParams) (*http.Request, *bytes.Buffer, error) { var bodyReader io.Reader var bodyBuffer *bytes.Buffer if params.Body != nil { - bodyBuffer = new(bytes.Buffer) - if err := json.NewEncoder(bodyBuffer).Encode(params.Body); err != nil { + var ok bool + bodyReader, ok = params.Body.(io.Reader) + if !ok { if c.debug && c.logger != nil { - c.logger.Errorf("failed to encode body: %v", err) + c.logger.Errorf("failed to read body: params.Body is not an io.Reader") } - return nil, nil, fmt.Errorf("failed to encode body: %w", err) + return nil, nil, fmt.Errorf("failed to read body: params.Body is not an io.Reader") } - bodyReader = bodyBuffer } req, err := http.NewRequestWithContext(ctx, method, url, bodyReader) @@ -248,17 +307,25 @@ func (c *httpClient) createRequest(ctx context.Context, method, url string, para return nil, nil, fmt.Errorf("failed to create request: %w", err) } + // Set the default headers req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") if c.userAgent != "" { req.Header.Set("User-Agent", c.userAgent) } + // Set additional headers added to the client + for name, values := range c.header { + for _, value := range values { + req.Header.Set(name, value) + } + } + return req, bodyBuffer, nil } // nolint:unused -func (c *httpClient) applyBeforeRequest(req *http.Request) error { +func (c *Client) applyBeforeRequest(req *http.Request) error { for _, mutate := range c.onBeforeRequest { if err := mutate(req); err != nil { if c.debug && c.logger != nil { @@ -267,11 +334,12 @@ func (c *httpClient) applyBeforeRequest(req *http.Request) error { return fmt.Errorf("failed to mutate before request: %w", err) } } + return nil } // nolint:unused -func (c *httpClient) applyAfterResponse(resp *http.Response) error { +func (c *Client) applyAfterResponse(resp *http.Response) error { for _, mutate := range c.onAfterResponse { if err := mutate(resp); err != nil { if c.debug && c.logger != nil { @@ -284,7 +352,7 @@ func (c *httpClient) applyAfterResponse(resp *http.Response) error { } // nolint:unused -func (c *httpClient) logRequest(req *http.Request, method, url string, bodyBuffer *bytes.Buffer) { +func (c *Client) logRequest(req *http.Request, method, url string, bodyBuffer *bytes.Buffer) { var reqBody string if bodyBuffer != nil { reqBody = bodyBuffer.String() @@ -292,12 +360,21 @@ func (c *httpClient) logRequest(req *http.Request, method, url string, bodyBuffe reqBody = "nil" } + var reqLog = &RequestLog{ + Method: method, + URL: url, + Headers: req.Header, + Body: reqBody, + } + + c.requestLog(reqLog) + var logBuf bytes.Buffer err := reqLogTemplate.Execute(&logBuf, map[string]interface{}{ - "Method": method, - "URL": url, - "Headers": req.Header, - "Body": reqBody, + "Method": reqLog.Method, + "URL": reqLog.URL, + "Headers": reqLog.Headers, + "Body": reqLog.Body, }) if err == nil { c.logger.Debugf(logBuf.String()) @@ -305,7 +382,7 @@ func (c *httpClient) logRequest(req *http.Request, method, url string, bodyBuffe } // nolint:unused -func (c *httpClient) sendRequest(req *http.Request) (*http.Response, error) { +func (c *Client) sendRequest(req *http.Request) (*http.Response, error) { resp, err := c.httpClient.Do(req) if err != nil { if c.debug && c.logger != nil { @@ -317,8 +394,8 @@ func (c *httpClient) sendRequest(req *http.Request) (*http.Response, error) { } // nolint:unused -func (c *httpClient) checkHTTPError(resp *http.Response) error { - _, err := coupleAPIErrorsHTTP(resp, nil) +func (c *Client) checkHTTPError(resp *http.Response) error { + _, err := coupleAPIErrors(resp, nil) if err != nil { if c.debug && c.logger != nil { c.logger.Errorf("received HTTP error: %v", err) @@ -329,7 +406,7 @@ func (c *httpClient) checkHTTPError(resp *http.Response) error { } // nolint:unused -func (c *httpClient) logResponse(resp *http.Response) (*http.Response, error) { +func (c *Client) logResponse(resp *http.Response) (*http.Response, error) { var respBody bytes.Buffer if _, err := io.Copy(&respBody, resp.Body); err != nil { c.logger.Errorf("failed to read response body: %v", err) @@ -350,7 +427,7 @@ func (c *httpClient) logResponse(resp *http.Response) (*http.Response, error) { } // nolint:unused -func (c *httpClient) decodeResponseBody(resp *http.Response, response interface{}) error { +func (c *Client) decodeResponseBody(resp *http.Response, response interface{}) error { if err := json.NewDecoder(resp.Body).Decode(response); err != nil { if c.debug && c.logger != nil { c.logger.Errorf("failed to decode response: %v", err) @@ -360,71 +437,28 @@ func (c *httpClient) decodeResponseBody(resp *http.Response, response interface{ return nil } -// R wraps resty's R method -func (c *Client) R(ctx context.Context) *resty.Request { - return c.resty.R(). - ExpectContentType("application/json"). - SetHeader("Content-Type", "application/json"). - SetContext(ctx). - SetError(APIError{}) -} - -// SetDebug sets the debug on resty's client -func (c *Client) SetDebug(debug bool) *Client { - c.debug = debug - c.resty.SetDebug(debug) - - return c -} - -// SetLogger allows the user to override the output -// logger for debug logs. -func (c *Client) SetLogger(logger Logger) *Client { - c.resty.SetLogger(logger) - - return c -} - //nolint:unused -func (c *httpClient) httpSetDebug(debug bool) *httpClient { +func (c *Client) SetDebug(debug bool) *Client { c.debug = debug return c } //nolint:unused -func (c *httpClient) httpSetLogger(logger httpLogger) *httpClient { +func (c *Client) SetLogger(logger Logger) *Client { c.logger = logger return c } -// OnBeforeRequest adds a handler to the request body to run before the request is sent -func (c *Client) OnBeforeRequest(m func(request *Request) error) { - c.resty.OnBeforeRequest(func(_ *resty.Client, req *resty.Request) error { - return m(req) - }) -} - -// OnAfterResponse adds a handler to the request body to run before the request is sent -func (c *Client) OnAfterResponse(m func(response *Response) error) { - c.resty.OnAfterResponse(func(_ *resty.Client, req *resty.Response) error { - return m(req) - }) -} - // nolint:unused -func (c *httpClient) httpOnBeforeRequest(m func(*http.Request) error) *httpClient { +func (c *Client) OnBeforeRequest(m func(*http.Request) error) { c.onBeforeRequest = append(c.onBeforeRequest, m) - - return c } // nolint:unused -func (c *httpClient) httpOnAfterResponse(m func(*http.Response) error) *httpClient { +func (c *Client) OnAfterResponse(m func(*http.Response) error) { c.onAfterResponse = append(c.onAfterResponse, m) - - return c } // UseURL parses the individual components of the given API URL and configures the client @@ -458,7 +492,6 @@ func (c *Client) UseURL(apiURL string) (*Client, error) { return c, nil } -// SetBaseURL sets the base URL of the Linode v4 API (https://api.linode.com/v4) func (c *Client) SetBaseURL(baseURL string) *Client { baseURLPath, _ := url.Parse(baseURL) @@ -496,51 +529,66 @@ func (c *Client) updateHostURL() { apiProto = c.apiProto } - c.resty.SetBaseURL( - fmt.Sprintf( - "%s://%s/%s", - apiProto, - baseURL, - url.PathEscape(apiVersion), - ), - ) + c.hostURL = strings.TrimRight(fmt.Sprintf("%s://%s/%s", apiProto, baseURL, url.PathEscape(apiVersion)), "/") +} + +func (c *Client) Transport() (*http.Transport, error) { + if transport, ok := c.httpClient.Transport.(*http.Transport); ok { + return transport, nil + } + return nil, fmt.Errorf("current transport is not an *http.Transport instance") +} + +func (c *Client) tlsConfig() (*tls.Config, error) { + transport, err := c.Transport() + if err != nil { + return nil, err + } + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{} + } + return transport.TLSClientConfig, nil } // SetRootCertificate adds a root certificate to the underlying TLS client config func (c *Client) SetRootCertificate(path string) *Client { - c.resty.SetRootCertificate(path) + config, err := c.tlsConfig() + if err != nil { + c.logger.Errorf("%v", err) + return c + } + if config.RootCAs == nil { + config.RootCAs = x509.NewCertPool() + } + + config.RootCAs.AppendCertsFromPEM([]byte(path)) return c } // SetToken sets the API token for all requests from this client // Only necessary if you haven't already provided the http client to NewClient() configured with the token. func (c *Client) SetToken(token string) *Client { - c.resty.SetHeader("Authorization", fmt.Sprintf("Bearer %s", token)) + c.SetHeader("Authorization", fmt.Sprintf("Bearer %s", token)) return c } // SetRetries adds retry conditions for "Linode Busy." errors and 429s. func (c *Client) SetRetries() *Client { c. - addRetryConditional(linodeBusyRetryCondition). - addRetryConditional(tooManyRequestsRetryCondition). - addRetryConditional(serviceUnavailableRetryCondition). - addRetryConditional(requestTimeoutRetryCondition). - addRetryConditional(requestGOAWAYRetryCondition). - addRetryConditional(requestNGINXRetryCondition). + AddRetryCondition(LinodeBusyRetryCondition). + AddRetryCondition(TooManyRequestsRetryCondition). + AddRetryCondition(ServiceUnavailableRetryCondition). + AddRetryCondition(RequestTimeoutRetryCondition). + AddRetryCondition(RequestGOAWAYRetryCondition). + AddRetryCondition(RequestNGINXRetryCondition). SetRetryMaxWaitTime(APIRetryMaxWaitTime) - configureRetries(c) + ConfigureRetries(c) return c } // AddRetryCondition adds a RetryConditional function to the Client func (c *Client) AddRetryCondition(retryCondition RetryConditional) *Client { - c.resty.AddRetryCondition(resty.RetryConditionFunc(retryCondition)) - return c -} - -func (c *Client) addRetryConditional(retryConditional RetryConditional) *Client { - c.retryConditionals = append(c.retryConditionals, retryConditional) + c.retryConditionals = append(c.retryConditionals, retryCondition) return c } @@ -655,26 +703,26 @@ func (c *Client) UseCache(value bool) { // SetRetryMaxWaitTime sets the maximum delay before retrying a request. func (c *Client) SetRetryMaxWaitTime(maxWaitTime time.Duration) *Client { - c.resty.SetRetryMaxWaitTime(maxWaitTime) + c.retryMaxWaitTime = maxWaitTime return c } // SetRetryWaitTime sets the default (minimum) delay before retrying a request. func (c *Client) SetRetryWaitTime(minWaitTime time.Duration) *Client { - c.resty.SetRetryWaitTime(minWaitTime) + c.retryMinWaitTime = minWaitTime return c } // SetRetryAfter sets the callback function to be invoked with a failed request // to determine wben it should be retried. func (c *Client) SetRetryAfter(callback RetryAfter) *Client { - c.resty.SetRetryAfter(resty.RetryAfterFunc(callback)) + c.retryAfter = callback return c } // SetRetryCount sets the maximum retry attempts before aborting. func (c *Client) SetRetryCount(count int) *Client { - c.resty.SetRetryCount(count) + c.retryCount = count return c } @@ -695,13 +743,29 @@ func (c *Client) GetPollDelay() time.Duration { // client. // NOTE: Some headers may be overridden by the individual request functions. func (c *Client) SetHeader(name, value string) { - c.resty.SetHeader(name, value) + if c.header == nil { + c.header = make(http.Header) // Initialize header if nil + } + c.header.Set(name, value) +} + +func (c *Client) onRequestLog(rl func(*RequestLog) error) *Client { + if c.requestLog != nil { + c.logger.Warnf("Overwriting an existing on-request-log callback from=%s to=%s", + functionName(c.requestLog), functionName(rl)) + } + c.requestLog = rl + return c +} + +func functionName(i interface{}) string { + return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() } func (c *Client) enableLogSanitization() *Client { - c.resty.OnRequestLog(func(r *resty.RequestLog) error { + c.onRequestLog(func(r *RequestLog) error { // masking authorization header - r.Header.Set("Authorization", "Bearer *******************************") + r.Headers.Set("Authorization", "Bearer *******************************") return nil }) @@ -711,20 +775,25 @@ func (c *Client) enableLogSanitization() *Client { // NewClient factory to create new Client struct func NewClient(hc *http.Client) (client Client) { if hc != nil { - client.resty = resty.NewWithClient(hc) + client.httpClient = hc } else { - client.resty = resty.New() + client.httpClient = &http.Client{} + } + + // Ensure that the Header map is not nil + if client.httpClient.Transport == nil { + client.httpClient.Transport = &http.Transport{} } client.shouldCache = true client.cacheExpiration = APIDefaultCacheExpiration client.cachedEntries = make(map[string]clientCacheEntry) client.cachedEntryLock = &sync.RWMutex{} + client.configProfiles = make(map[string]ConfigProfile) client.SetUserAgent(DefaultUserAgent) baseURL, baseURLExists := os.LookupEnv(APIHostVar) - if baseURLExists { client.SetBaseURL(baseURL) } @@ -736,7 +805,6 @@ func NewClient(hc *http.Client) (client Client) { } certPath, certPathExists := os.LookupEnv(APIHostCert) - if certPathExists { cert, err := os.ReadFile(filepath.Clean(certPath)) if err != nil { diff --git a/client_http.go b/client_http.go deleted file mode 100644 index 7f16362c..00000000 --- a/client_http.go +++ /dev/null @@ -1,56 +0,0 @@ -package linodego - -import ( - "net/http" - "sync" - "time" -) - -// Client is a wrapper around the Resty client -// -//nolint:unused -type httpClient struct { - //nolint:unused - httpClient *http.Client - //nolint:unused - userAgent string - //nolint:unused - debug bool - //nolint:unused - retryConditionals []httpRetryConditional - //nolint:unused - retryAfter httpRetryAfter - - //nolint:unused - pollInterval time.Duration - - //nolint:unused - baseURL string - //nolint:unused - apiVersion string - //nolint:unused - apiProto string - //nolint:unused - selectedProfile string - //nolint:unused - loadedProfile string - - //nolint:unused - configProfiles map[string]ConfigProfile - - // Fields for caching endpoint responses - //nolint:unused - shouldCache bool - //nolint:unused - cacheExpiration time.Duration - //nolint:unused - cachedEntries map[string]clientCacheEntry - //nolint:unused - cachedEntryLock *sync.RWMutex - //nolint:unused - logger httpLogger - //nolint:unused - onBeforeRequest []func(*http.Request) error - //nolint:unused - onAfterResponse []func(*http.Response) error -} diff --git a/client_test.go b/client_test.go index 93b133f9..bf58b5a7 100644 --- a/client_test.go +++ b/client_test.go @@ -5,6 +5,8 @@ import ( "context" "errors" "fmt" + "github.com/jarcoal/httpmock" + "github.com/linode/linodego/internal/testutil" "net/http" "net/http/httptest" "reflect" @@ -12,8 +14,6 @@ import ( "testing" "github.com/google/go-cmp/cmp" - "github.com/jarcoal/httpmock" - "github.com/linode/linodego/internal/testutil" ) func TestClient_SetAPIVersion(t *testing.T) { @@ -33,39 +33,39 @@ func TestClient_SetAPIVersion(t *testing.T) { client := NewClient(nil) - if client.resty.BaseURL != defaultURL { - t.Fatal(cmp.Diff(client.resty.BaseURL, defaultURL)) + if client.hostURL != defaultURL { + t.Fatal(cmp.Diff(client.hostURL, defaultURL)) } client.SetBaseURL(baseURL) client.SetAPIVersion(apiVersion) - if client.resty.BaseURL != expectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, expectedHost)) + if client.hostURL != expectedHost { + t.Fatal(cmp.Diff(client.hostURL, expectedHost)) } // Ensure setting twice does not cause conflicts client.SetBaseURL(updatedBaseURL) client.SetAPIVersion(updatedAPIVersion) - if client.resty.BaseURL != updatedExpectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, updatedExpectedHost)) + if client.hostURL != updatedExpectedHost { + t.Fatal(cmp.Diff(client.hostURL, updatedExpectedHost)) } // Revert client.SetBaseURL(baseURL) client.SetAPIVersion(apiVersion) - if client.resty.BaseURL != expectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, expectedHost)) + if client.hostURL != expectedHost { + t.Fatal(cmp.Diff(client.hostURL, expectedHost)) } // Custom protocol client.SetBaseURL(protocolBaseURL) client.SetAPIVersion(protocolAPIVersion) - if client.resty.BaseURL != protocolExpectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, expectedHost)) + if client.hostURL != protocolExpectedHost { + t.Fatal(cmp.Diff(client.hostURL, expectedHost)) } } @@ -107,7 +107,7 @@ func TestClient_NewFromEnvToken(t *testing.T) { t.Fatal(err) } - if client.resty.Header.Get("Authorization") != "Bearer blah" { + if client.header.Get("Authorization") != "Bearer blah" { t.Fatal("token not found in auth header: blah") } } @@ -171,16 +171,16 @@ func TestDebugLogSanitization(t *testing.T) { logger.L.SetOutput(&lgr) mockClient.SetDebug(true) - if !mockClient.resty.Debug { + if !mockClient.debug { t.Fatal("debug should be enabled") } mockClient.SetHeader("Authorization", fmt.Sprintf("Bearer %s", plainTextToken)) - if mockClient.resty.Header.Get("Authorization") != fmt.Sprintf("Bearer %s", plainTextToken) { + if mockClient.header.Get("Authorization") != fmt.Sprintf("Bearer %s", plainTextToken) { t.Fatal("token not found in auth header") } - httpmock.RegisterRegexpResponder("GET", testutil.MockRequestURL("/linode/instances"), + httpmock.RegisterResponder("GET", "/linode/instances", httpmock.NewJsonResponderOrPanic(200, &testResp)) result, err := doGETRequest[instanceResponse]( @@ -211,15 +211,13 @@ func TestDoRequest_Success(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - client := &httpClient{ - httpClient: server.Client(), - } + client := NewClient(server.Client()) params := RequestParams{ Response: &map[string]string{}, } - err := client.doRequest(context.Background(), http.MethodGet, server.URL, params) + err := client.doRequest(context.Background(), http.MethodGet, server.URL, params, nil) if err != nil { t.Fatal(cmp.Diff(nil, err)) } @@ -231,10 +229,8 @@ func TestDoRequest_Success(t *testing.T) { } } -func TestDoRequest_FailedEncodeBody(t *testing.T) { - client := &httpClient{ - httpClient: http.DefaultClient, - } +func TestDoRequest_FailedReadBody(t *testing.T) { + client := NewClient(nil) params := RequestParams{ Body: map[string]interface{}{ @@ -242,20 +238,18 @@ func TestDoRequest_FailedEncodeBody(t *testing.T) { }, } - err := client.doRequest(context.Background(), http.MethodPost, "http://example.com", params) - expectedErr := "failed to encode body" + err := client.doRequest(context.Background(), http.MethodPost, "http://example.com", params, nil) + expectedErr := "failed to read body: params.Body is not an io.Reader" if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Fatalf("expected error %q, got: %v", expectedErr, err) } } func TestDoRequest_FailedCreateRequest(t *testing.T) { - client := &httpClient{ - httpClient: http.DefaultClient, - } + client := NewClient(nil) // Create a request with an invalid URL to simulate a request creation failure - err := client.doRequest(context.Background(), http.MethodGet, "http://invalid url", RequestParams{}) + err := client.doRequest(context.Background(), http.MethodGet, "http://invalid url", RequestParams{}, nil) expectedErr := "failed to create request" if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Fatalf("expected error %q, got: %v", expectedErr, err) @@ -269,14 +263,13 @@ func TestDoRequest_Non2xxStatusCode(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - client := &httpClient{ - httpClient: server.Client(), - } + client := NewClient(server.Client()) - err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}) + err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}, nil) if err == nil { t.Fatal("expected error, got nil") } + httpError, ok := err.(Error) if !ok { t.Fatalf("expected error to be of type Error, got %T", err) @@ -298,15 +291,13 @@ func TestDoRequest_FailedDecodeResponse(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - client := &httpClient{ - httpClient: server.Client(), - } + client := NewClient(server.Client()) params := RequestParams{ Response: &map[string]string{}, } - err := client.doRequest(context.Background(), http.MethodGet, server.URL, params) + err := client.doRequest(context.Background(), http.MethodGet, server.URL, params, nil) expectedErr := "failed to decode response" if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Fatalf("expected error %q, got: %v", expectedErr, err) @@ -325,9 +316,7 @@ func TestDoRequest_BeforeRequestSuccess(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - client := &httpClient{ - httpClient: server.Client(), - } + client := NewClient(server.Client()) // Define a mutator that successfully modifies the request mutator := func(req *http.Request) error { @@ -335,9 +324,9 @@ func TestDoRequest_BeforeRequestSuccess(t *testing.T) { return nil } - client.httpOnBeforeRequest(mutator) + client.OnBeforeRequest(mutator) - err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}) + err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}, nil) if err != nil { t.Fatalf("expected no error, got: %v", err) } @@ -357,17 +346,15 @@ func TestDoRequest_BeforeRequestError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - client := &httpClient{ - httpClient: server.Client(), - } + client := NewClient(server.Client()) mutator := func(req *http.Request) error { return errors.New("mutator error") } - client.httpOnBeforeRequest(mutator) + client.OnBeforeRequest(mutator) - err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}) + err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}, nil) expectedErr := "failed to mutate before request" if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Fatalf("expected error %q, got: %v", expectedErr, err) @@ -387,18 +374,16 @@ func TestDoRequest_AfterResponseSuccess(t *testing.T) { tr := &testRoundTripper{ Transport: server.Client().Transport, } - client := &httpClient{ - httpClient: &http.Client{Transport: tr}, - } + client := NewClient(&http.Client{Transport: tr}) mutator := func(resp *http.Response) error { resp.Header.Set("X-Modified-Header", "ModifiedValue") return nil } - client.httpOnAfterResponse(mutator) + client.OnAfterResponse(mutator) - err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}) + err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}, nil) if err != nil { t.Fatalf("expected no error, got: %v", err) } @@ -418,17 +403,15 @@ func TestDoRequest_AfterResponseError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - client := &httpClient{ - httpClient: server.Client(), - } + client := NewClient(server.Client()) mutator := func(resp *http.Response) error { return errors.New("mutator error") } - client.httpOnAfterResponse(mutator) + client.OnAfterResponse(mutator) - err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}) + err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}, nil) expectedErr := "failed to mutate after response" if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Fatalf("expected error %q, got: %v", expectedErr, err) @@ -440,11 +423,9 @@ func TestDoRequestLogging_Success(t *testing.T) { logger := createLogger() logger.l.SetOutput(&logBuffer) // Redirect log output to buffer - client := &httpClient{ - httpClient: http.DefaultClient, - debug: true, - logger: logger, - } + client := NewClient(nil) + client.SetDebug(true) + client.SetLogger(logger) handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -458,7 +439,7 @@ func TestDoRequestLogging_Success(t *testing.T) { Response: &map[string]string{}, } - err := client.doRequest(context.Background(), http.MethodGet, server.URL, params) + err := client.doRequest(context.Background(), http.MethodGet, server.URL, params, nil) if err != nil { t.Fatal(cmp.Diff(nil, err)) } @@ -467,8 +448,8 @@ func TestDoRequestLogging_Success(t *testing.T) { logInfoWithoutTimestamps := removeTimestamps(logInfo) // Expected logs with templates filled in - expectedRequestLog := "DEBUG RESTY Sending request:\nMethod: GET\nURL: " + server.URL + "\nHeaders: map[Accept:[application/json] Content-Type:[application/json]]\nBody: " - expectedResponseLog := "DEBUG RESTY Received response:\nStatus: 200 OK\nHeaders: map[Content-Length:[21] Content-Type:[text/plain; charset=utf-8]]\nBody: {\"message\":\"success\"}" + expectedRequestLog := "DEBUG Sending request:\nMethod: GET\nURL: " + server.URL + "\nHeaders: map[Accept:[application/json] Authorization:[Bearer *******************************] Content-Type:[application/json] User-Agent:[linodego/dev https://github.com/linode/linodego]]\nBody: " + expectedResponseLog := "DEBUG Received response:\nStatus: 200 OK\nHeaders: map[Content-Length:[21] Content-Type:[text/plain; charset=utf-8]]\nBody: {\"message\":\"success\"}" if !strings.Contains(logInfo, expectedRequestLog) { t.Fatalf("expected log %q not found in logs", expectedRequestLog) @@ -483,11 +464,9 @@ func TestDoRequestLogging_Error(t *testing.T) { logger := createLogger() logger.l.SetOutput(&logBuffer) // Redirect log output to buffer - client := &httpClient{ - httpClient: http.DefaultClient, - debug: true, - logger: logger, - } + client := NewClient(nil) + client.SetDebug(true) + client.SetLogger(logger) params := RequestParams{ Body: map[string]interface{}{ @@ -495,14 +474,14 @@ func TestDoRequestLogging_Error(t *testing.T) { }, } - err := client.doRequest(context.Background(), http.MethodPost, "http://example.com", params) - expectedErr := "failed to encode body" + err := client.doRequest(context.Background(), http.MethodPost, "http://example.com", params, nil) + expectedErr := "failed to read body: params.Body is not an io.Reader" if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Fatalf("expected error %q, got: %v", expectedErr, err) } logInfo := logBuffer.String() - expectedLog := "ERROR RESTY failed to encode body" + expectedLog := "ERROR failed to read body: params.Body is not an io.Reader" if !strings.Contains(logInfo, expectedLog) { t.Fatalf("expected log %q not found in logs", expectedLog) diff --git a/config_test.go b/config_test.go index b4b3db41..628cc141 100644 --- a/config_test.go +++ b/config_test.go @@ -42,11 +42,11 @@ func TestConfig_LoadWithDefaults(t *testing.T) { expectedURL := "https://api.cool.linode.com/v4beta" - if client.resty.BaseURL != expectedURL { - t.Fatalf("mismatched host url: %s != %s", client.resty.BaseURL, expectedURL) + if client.hostURL != expectedURL { + t.Fatalf("mismatched host url: %s != %s", client.hostURL, expectedURL) } - if client.resty.Header.Get("Authorization") != "Bearer "+p.APIToken { + if client.header.Get("Authorization") != "Bearer "+p.APIToken { t.Fatalf("token not found in auth header: %s", p.APIToken) } } @@ -88,11 +88,11 @@ func TestConfig_OverrideDefaults(t *testing.T) { expectedURL := "https://api.cool.linode.com/v4" - if client.resty.BaseURL != expectedURL { - t.Fatalf("mismatched host url: %s != %s", client.resty.BaseURL, expectedURL) + if client.hostURL != expectedURL { + t.Fatalf("mismatched host url: %s != %s", client.hostURL, expectedURL) } - if client.resty.Header.Get("Authorization") != "Bearer "+p.APIToken { + if client.header.Get("Authorization") != "Bearer "+p.APIToken { t.Fatalf("token not found in auth header: %s", p.APIToken) } } @@ -124,7 +124,7 @@ func TestConfig_NoDefaults(t *testing.T) { t.Fatalf("mismatched api token: %s != %s", p.APIToken, "mytoken") } - if client.resty.Header.Get("Authorization") != "Bearer "+p.APIToken { + if client.header.Get("Authorization") != "Bearer "+p.APIToken { t.Fatalf("token not found in auth header: %s", p.APIToken) } } diff --git a/errors.go b/errors.go index be15c014..e3ed7a46 100644 --- a/errors.go +++ b/errors.go @@ -53,51 +53,8 @@ func (r APIErrorReason) String() string { return fmt.Sprintf("[%s] %s", r.Field, r.Reason) } -func coupleAPIErrors(r *resty.Response, err error) (*resty.Response, error) { - if err != nil { - // an error was raised in go code, no need to check the resty Response - return nil, NewError(err) - } - - if r.Error() == nil { - // no error in the resty Response - return r, nil - } - - // handle the resty Response errors - - // Check that response is of the correct content-type before unmarshalling - expectedContentType := r.Request.Header.Get("Accept") - responseContentType := r.Header().Get("Content-Type") - - // If the upstream Linode API server being fronted fails to respond to the request, - // the http server will respond with a default "Bad Gateway" page with Content-Type - // "text/html". - if r.StatusCode() == http.StatusBadGateway && responseContentType == "text/html" { //nolint:goconst - return nil, Error{Code: http.StatusBadGateway, Message: http.StatusText(http.StatusBadGateway)} - } - - if responseContentType != expectedContentType { - msg := fmt.Sprintf( - "Unexpected Content-Type: Expected: %v, Received: %v\nResponse body: %s", - expectedContentType, - responseContentType, - string(r.Body()), - ) - - return nil, Error{Code: r.StatusCode(), Message: msg} - } - - apiError, ok := r.Error().(*APIError) - if !ok || (ok && len(apiError.Errors) == 0) { - return r, nil - } - - return nil, NewError(r) -} - //nolint:unused -func coupleAPIErrorsHTTP(resp *http.Response, err error) (*http.Response, error) { +func coupleAPIErrors(resp *http.Response, err error) (*http.Response, error) { if err != nil { // an error was raised in go code, no need to check the http.Response return nil, NewError(err) diff --git a/errors_test.go b/errors_test.go index 68428fcd..19ea5105 100644 --- a/errors_test.go +++ b/errors_test.go @@ -119,98 +119,6 @@ func TestCoupleAPIErrors(t *testing.T) { } }) - t.Run("resty 500 response error with reasons", func(t *testing.T) { - if _, err := coupleAPIErrors(restyError("testreason", "testfield"), nil); err.Error() != "[500] [testfield] testreason" { - t.Error("resty error should return with proper format [code] [field] reason") - } - }) - - t.Run("resty 500 response error without reasons", func(t *testing.T) { - if _, err := coupleAPIErrors(restyError("", ""), nil); err != nil { - t.Error("resty error with no reasons should return no error") - } - }) - - t.Run("resty response with nil error", func(t *testing.T) { - emptyErr := &resty.Response{ - RawResponse: &http.Response{ - StatusCode: 500, - }, - Request: &resty.Request{ - Error: nil, - }, - } - if _, err := coupleAPIErrors(emptyErr, nil); err != nil { - t.Error("resty error with no reasons should return no error") - } - }) - - t.Run("generic html error", func(t *testing.T) { - rawResponse := ` -500 Internal Server Error - -

500 Internal Server Error

-
nginx
- -` - route := "/v4/linode/instances/123" - ts, client := createTestServer(http.MethodGet, route, "text/html", rawResponse, http.StatusInternalServerError) - // client.SetDebug(true) - defer ts.Close() - - expectedError := Error{ - Code: http.StatusInternalServerError, - Message: "Unexpected Content-Type: Expected: application/json, Received: text/html\nResponse body: " + rawResponse, - } - - _, err := coupleAPIErrors(client.R(context.Background()).SetResult(&Instance{}).Get(ts.URL + route)) - if diff := cmp.Diff(expectedError, err); diff != "" { - t.Errorf("expected error to match but got diff:\n%s", diff) - } - }) - - t.Run("bad gateway error", func(t *testing.T) { - rawResponse := []byte(` -502 Bad Gateway - -

502 Bad Gateway

-
nginx
- -`) - buf := io.NopCloser(bytes.NewBuffer(rawResponse)) - - resp := &resty.Response{ - Request: &resty.Request{ - Error: errors.New("Bad Gateway"), - }, - RawResponse: &http.Response{ - Header: http.Header{ - "Content-Type": []string{"text/html"}, - }, - StatusCode: http.StatusBadGateway, - Body: buf, - }, - } - - expectedError := Error{ - Code: http.StatusBadGateway, - Message: http.StatusText(http.StatusBadGateway), - } - - if _, err := coupleAPIErrors(resp, nil); !cmp.Equal(err, expectedError) { - t.Errorf("expected error %#v to match error %#v", err, expectedError) - } - }) -} - -func TestCoupleAPIErrorsHTTP(t *testing.T) { - t.Run("not nil error generates error", func(t *testing.T) { - err := errors.New("test") - if _, err := coupleAPIErrorsHTTP(nil, err); !cmp.Equal(err, NewError(err)) { - t.Errorf("expect a not nil error to be returned as an Error") - } - }) - t.Run("http 500 response error with reasons", func(t *testing.T) { // Create the simulated HTTP response with a 500 status and a JSON body containing the error details apiError := APIError{ @@ -228,7 +136,7 @@ func TestCoupleAPIErrorsHTTP(t *testing.T) { Request: &http.Request{Header: http.Header{"Accept": []string{"application/json"}}}, } - _, err := coupleAPIErrorsHTTP(resp, nil) + _, err := coupleAPIErrors(resp, nil) expectedMessage := "[500] [testfield] testreason" if err == nil || err.Error() != expectedMessage { t.Errorf("expected error message %q, got: %v", expectedMessage, err) @@ -250,7 +158,7 @@ func TestCoupleAPIErrorsHTTP(t *testing.T) { Request: &http.Request{Header: http.Header{"Accept": []string{"application/json"}}}, } - _, err := coupleAPIErrorsHTTP(resp, nil) + _, err := coupleAPIErrors(resp, nil) if err != nil { t.Error("http error with no reasons should return no error") } @@ -265,7 +173,7 @@ func TestCoupleAPIErrorsHTTP(t *testing.T) { Request: &http.Request{Header: http.Header{"Accept": []string{"application/json"}}}, } - _, err := coupleAPIErrorsHTTP(resp, nil) + _, err := coupleAPIErrors(resp, nil) if err != nil { t.Error("http error with no reasons should return no error") } @@ -288,7 +196,7 @@ func TestCoupleAPIErrorsHTTP(t *testing.T) { })) defer ts.Close() - client := &httpClient{ + client := &Client{ httpClient: ts.Client(), } @@ -310,7 +218,7 @@ func TestCoupleAPIErrorsHTTP(t *testing.T) { } defer resp.Body.Close() - _, err = coupleAPIErrorsHTTP(resp, nil) + _, err = coupleAPIErrors(resp, nil) if diff := cmp.Diff(expectedError, err); diff != "" { t.Errorf("expected error to match but got diff:\n%s", diff) } @@ -342,7 +250,7 @@ func TestCoupleAPIErrorsHTTP(t *testing.T) { Message: http.StatusText(http.StatusBadGateway), } - _, err := coupleAPIErrorsHTTP(resp, nil) + _, err := coupleAPIErrors(resp, nil) if !cmp.Equal(err, expectedError) { t.Errorf("expected error %#v to match error %#v", err, expectedError) } diff --git a/images.go b/images.go index dd1d5f3f..1bd68aba 100644 --- a/images.go +++ b/images.go @@ -4,9 +4,9 @@ import ( "context" "encoding/json" "io" + "net/http" "time" - "github.com/go-resty/resty/v2" "github.com/linode/linodego/internal/parseabletime" ) @@ -218,17 +218,22 @@ func (c *Client) CreateImageUpload(ctx context.Context, opts ImageCreateUploadOp // UploadImageToURL uploads the given image to the given upload URL. func (c *Client) UploadImageToURL(ctx context.Context, uploadURL string, image io.Reader) error { - // Linode-specific headers do not need to be sent to this endpoint - req := resty.New().SetDebug(c.resty.Debug).R(). - SetContext(ctx). - SetContentLength(true). - SetHeader("Content-Type", "application/octet-stream"). - SetBody(image) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, uploadURL, image) + if err != nil { + return err + } - _, err := coupleAPIErrors(req. - Put(uploadURL)) + req.Header.Set("Content-Type", "application/octet-stream") + req.ContentLength = -1 // Automatically calculate content length - return err + resp, err := c.httpClient.Do(req) + + _, err = coupleAPIErrors(resp, err) + if err != nil { + return err + } + + return nil } // UploadImage creates and uploads an image. diff --git a/internal/testutil/mock.go b/internal/testutil/mock.go index aaaeebb4..a033f6d1 100644 --- a/internal/testutil/mock.go +++ b/internal/testutil/mock.go @@ -108,15 +108,15 @@ type TestLogger struct { } func (l *TestLogger) Errorf(format string, v ...interface{}) { - l.outputf("ERROR RESTY "+format, v...) + l.outputf("ERROR "+format, v...) } func (l *TestLogger) Warnf(format string, v ...interface{}) { - l.outputf("WARN RESTY "+format, v...) + l.outputf("WARN "+format, v...) } func (l *TestLogger) Debugf(format string, v ...interface{}) { - l.outputf("DEBUG RESTY "+format, v...) + l.outputf("DEBUG "+format, v...) } func (l *TestLogger) outputf(format string, v ...interface{}) { diff --git a/logger.go b/logger.go index 5de75859..4f71e6c3 100644 --- a/logger.go +++ b/logger.go @@ -6,7 +6,7 @@ import ( ) //nolint:unused -type httpLogger interface { +type Logger interface { Errorf(format string, v ...interface{}) Warnf(format string, v ...interface{}) Debugf(format string, v ...interface{}) @@ -24,21 +24,21 @@ func createLogger() *logger { } //nolint:unused -var _ httpLogger = (*logger)(nil) +var _ Logger = (*logger)(nil) //nolint:unused func (l *logger) Errorf(format string, v ...interface{}) { - l.output("ERROR RESTY "+format, v...) + l.output("ERROR "+format, v...) } //nolint:unused func (l *logger) Warnf(format string, v ...interface{}) { - l.output("WARN RESTY "+format, v...) + l.output("WARN "+format, v...) } //nolint:unused func (l *logger) Debugf(format string, v ...interface{}) { - l.output("DEBUG RESTY "+format, v...) + l.output("DEBUG "+format, v...) } //nolint:unused diff --git a/pagination.go b/pagination.go index 3b3f50ac..e8127a58 100644 --- a/pagination.go +++ b/pagination.go @@ -9,10 +9,9 @@ import ( "encoding/hex" "encoding/json" "fmt" + "net/http" "reflect" "strconv" - - "github.com/go-resty/resty/v2" ) // PageOptions are the pagination parameters for List endpoints @@ -56,38 +55,48 @@ func (l ListOptions) Hash() (string, error) { return hex.EncodeToString(h.Sum(nil)), nil } -func applyListOptionsToRequest(opts *ListOptions, req *resty.Request) error { +func createListOptionsToRequestMutator(opts *ListOptions) func(*http.Request) error { if opts == nil { return nil } - if opts.QueryParams != nil { - params, err := flattenQueryStruct(opts.QueryParams) - if err != nil { - return fmt.Errorf("failed to apply list options: %w", err) + // Return a mutator to apply query parameters and headers + return func(req *http.Request) error { + query := req.URL.Query() + + // Apply QueryParams from ListOptions if present + if opts.QueryParams != nil { + params, err := flattenQueryStruct(opts.QueryParams) + if err != nil { + return fmt.Errorf("failed to apply list options: %w", err) + } + for key, value := range params { + query.Set(key, value) + } } - req.SetQueryParams(params) - } - - if opts.PageOptions != nil && opts.Page > 0 { - req.SetQueryParam("page", strconv.Itoa(opts.Page)) - } + // Apply pagination options + if opts.PageOptions != nil && opts.Page > 0 { + query.Set("page", strconv.Itoa(opts.Page)) + } + if opts.PageSize > 0 { + query.Set("page_size", strconv.Itoa(opts.PageSize)) + } - if opts.PageSize > 0 { - req.SetQueryParam("page_size", strconv.Itoa(opts.PageSize)) - } + // Apply filters as headers + if len(opts.Filter) > 0 { + req.Header.Set("X-Filter", opts.Filter) + } - if len(opts.Filter) > 0 { - req.SetHeader("X-Filter", opts.Filter) + // Assign the updated query back to the request URL + req.URL.RawQuery = query.Encode() + return nil } - - return nil } type PagedResponse interface { endpoint(...any) string - castResult(*resty.Request, string) (int, int, error) + castResult(*http.Request, string) (int, int, error) } // flattenQueryStruct flattens a structure into a Resty-compatible query param map. diff --git a/request_helpers.go b/request_helpers.go index 152a2643..522b7722 100644 --- a/request_helpers.go +++ b/request_helpers.go @@ -1,9 +1,11 @@ package linodego import ( + "bytes" "context" "encoding/json" "fmt" + "net/http" "net/url" "reflect" ) @@ -26,8 +28,6 @@ func getPaginatedResults[T any]( endpoint string, opts *ListOptions, ) ([]T, error) { - var resultType paginatedResponse[T] - result := make([]T, 0) if opts == nil { @@ -41,34 +41,33 @@ func getPaginatedResults[T any]( // Makes a request to a particular page and // appends the response to the result handlePage := func(page int) error { - // Override the page to be applied in applyListOptionsToRequest(...) + var resultType paginatedResponse[T] opts.Page = page - // This request object cannot be reused for each page request - // because it can lead to possible data corruption - req := client.R(ctx).SetResult(resultType) - - // Apply all user-provided list options to the request - if err := applyListOptionsToRequest(opts, req); err != nil { - return err + params := RequestParams{ + Response: &resultType, } - res, err := coupleAPIErrors(req.Get(endpoint)) + // Create a mutator to all user-provided list options to the request + mutator := createListOptionsToRequestMutator(opts) + + // Make the request using doRequest + err := client.doRequest(ctx, http.MethodGet, endpoint, params, &mutator) if err != nil { return err } - response := res.Result().(*paginatedResponse[T]) - + // Extract the result from the response opts.Page = page - opts.Pages = response.Pages - opts.Results = response.Results + opts.Pages = resultType.Pages + opts.Results = resultType.Results - result = append(result, response.Data...) + // Append the data to the result slice + result = append(result, resultType.Data...) return nil } - // This helps simplify the logic below + // Determine starting page startingPage := 1 pageDefined := opts.Page > 0 @@ -81,13 +80,12 @@ func getPaginatedResults[T any]( return nil, err } - // If the user has explicitly specified a page, we don't - // need to get any other pages. + // If a specific page is defined, return the result if pageDefined { return result, nil } - // Get the rest of the pages + // Get the remaining pages for page := 2; page <= opts.Pages; page++ { if err := handlePage(page); err != nil { return nil, err @@ -105,14 +103,16 @@ func doGETRequest[T any]( endpoint string, ) (*T, error) { var resultType T + params := RequestParams{ + Response: &resultType, + } - req := client.R(ctx).SetResult(&resultType) - r, err := coupleAPIErrors(req.Get(endpoint)) + err := client.doRequest(ctx, http.MethodGet, endpoint, params, nil) if err != nil { return nil, err } - return r.Result().(*T), nil + return &resultType, nil } // doPOSTRequest runs a PUT request using the given client, API endpoint, @@ -124,29 +124,27 @@ func doPOSTRequest[T, O any]( options ...O, ) (*T, error) { var resultType T - numOpts := len(options) - if numOpts > 1 { - return nil, fmt.Errorf("invalid number of options: %d", len(options)) + return nil, fmt.Errorf("invalid number of options: %d", numOpts) } - req := client.R(ctx).SetResult(&resultType) - + params := RequestParams{ + Response: &resultType, + } if numOpts > 0 && !isNil(options[0]) { body, err := json.Marshal(options[0]) if err != nil { return nil, err } - req.SetBody(string(body)) + params.Body = bytes.NewReader(body) } - r, err := coupleAPIErrors(req.Post(endpoint)) + err := client.doRequest(ctx, http.MethodPost, endpoint, params, nil) if err != nil { return nil, err } - - return r.Result().(*T), nil + return &resultType, nil } // doPUTRequest runs a PUT request using the given client, API endpoint, @@ -158,29 +156,27 @@ func doPUTRequest[T, O any]( options ...O, ) (*T, error) { var resultType T - numOpts := len(options) - if numOpts > 1 { - return nil, fmt.Errorf("invalid number of options: %d", len(options)) + return nil, fmt.Errorf("invalid number of options: %d", numOpts) } - req := client.R(ctx).SetResult(&resultType) - + params := RequestParams{ + Response: &resultType, + } if numOpts > 0 && !isNil(options[0]) { body, err := json.Marshal(options[0]) if err != nil { return nil, err } - req.SetBody(string(body)) + params.Body = bytes.NewReader(body) } - r, err := coupleAPIErrors(req.Put(endpoint)) + err := client.doRequest(ctx, http.MethodPut, endpoint, params, nil) if err != nil { return nil, err } - - return r.Result().(*T), nil + return &resultType, nil } // doDELETERequest runs a DELETE request using the given client @@ -190,8 +186,8 @@ func doDELETERequest( client *Client, endpoint string, ) error { - req := client.R(ctx) - _, err := coupleAPIErrors(req.Delete(endpoint)) + params := RequestParams{} + err := client.doRequest(ctx, http.MethodDelete, endpoint, params, nil) return err } diff --git a/request_helpers_test.go b/request_helpers_test.go index bed2e740..db8bafe4 100644 --- a/request_helpers_test.go +++ b/request_helpers_test.go @@ -3,14 +3,13 @@ package linodego import ( "context" "fmt" + "github.com/stretchr/testify/require" "math" "net/http" "reflect" "strconv" "testing" - "github.com/stretchr/testify/require" - "github.com/linode/linodego/internal/testutil" "github.com/google/go-cmp/cmp" @@ -41,7 +40,7 @@ var testResponse = testResultType{ func TestRequestHelpers_get(t *testing.T) { client := testutil.CreateMockClient(t, NewClient) - httpmock.RegisterRegexpResponder("GET", testutil.MockRequestURL("/foo/bar"), + httpmock.RegisterResponder("GET", "/foo/bar", httpmock.NewJsonResponderOrPanic(200, &testResponse)) result, err := doGETRequest[testResultType]( @@ -61,7 +60,7 @@ func TestRequestHelpers_get(t *testing.T) { func TestRequestHelpers_post(t *testing.T) { client := testutil.CreateMockClient(t, NewClient) - httpmock.RegisterRegexpResponder("POST", testutil.MockRequestURL("/foo/bar"), + httpmock.RegisterResponder("POST", "/foo/bar", testutil.MockRequestBodyValidate(t, testResponse, testResponse)) result, err := doPOSTRequest[testResultType]( @@ -82,11 +81,8 @@ func TestRequestHelpers_post(t *testing.T) { func TestRequestHelpers_postNoOptions(t *testing.T) { client := testutil.CreateMockClient(t, NewClient) - httpmock.RegisterRegexpResponder( - "POST", - testutil.MockRequestURL("/foo/bar"), - testutil.MockRequestBodyValidateNoBody(t, testResponse), - ) + httpmock.RegisterResponder("POST", "/foo/bar", + testutil.MockRequestBodyValidateNoBody(t, testResponse)) result, err := doPOSTRequest[testResultType, any]( context.Background(), @@ -105,7 +101,7 @@ func TestRequestHelpers_postNoOptions(t *testing.T) { func TestRequestHelpers_put(t *testing.T) { client := testutil.CreateMockClient(t, NewClient) - httpmock.RegisterRegexpResponder("PUT", testutil.MockRequestURL("/foo/bar"), + httpmock.RegisterResponder("PUT", "/foo/bar", testutil.MockRequestBodyValidate(t, testResponse, testResponse)) result, err := doPUTRequest[testResultType]( @@ -126,11 +122,8 @@ func TestRequestHelpers_put(t *testing.T) { func TestRequestHelpers_putNoOptions(t *testing.T) { client := testutil.CreateMockClient(t, NewClient) - httpmock.RegisterRegexpResponder( - "PUT", - testutil.MockRequestURL("/foo/bar"), - testutil.MockRequestBodyValidateNoBody(t, testResponse), - ) + httpmock.RegisterResponder("PUT", "/foo/bar", + testutil.MockRequestBodyValidateNoBody(t, testResponse)) result, err := doPUTRequest[testResultType, any]( context.Background(), @@ -149,11 +142,8 @@ func TestRequestHelpers_putNoOptions(t *testing.T) { func TestRequestHelpers_delete(t *testing.T) { client := testutil.CreateMockClient(t, NewClient) - httpmock.RegisterRegexpResponder( - "DELETE", - testutil.MockRequestURL("/foo/bar/foo%20bar"), - httpmock.NewStringResponder(200, "{}"), - ) + httpmock.RegisterResponder("DELETE", "/foo/bar/foo%20bar", + httpmock.NewStringResponder(200, "{}")) if err := doDELETERequest( context.Background(), @@ -171,14 +161,8 @@ func TestRequestHelpers_paginateAll(t *testing.T) { numRequests := 0 - httpmock.RegisterRegexpResponder( - "GET", - testutil.MockRequestURL("/foo/bar"), - mockPaginatedResponse( - buildPaginatedEntries(totalResults), - &numRequests, - ), - ) + httpmock.RegisterResponder("GET", "/foo/bar", + mockPaginatedResponse(buildPaginatedEntries(totalResults), &numRequests)) response, err := getPaginatedResults[testResultType]( context.Background(), @@ -207,14 +191,8 @@ func TestRequestHelpers_paginateSingle(t *testing.T) { numRequests := 0 - httpmock.RegisterRegexpResponder( - "GET", - testutil.MockRequestURL("/foo/bar"), - mockPaginatedResponse( - buildPaginatedEntries(12), - &numRequests, - ), - ) + httpmock.RegisterResponder("GET", "/foo/bar", + mockPaginatedResponse(buildPaginatedEntries(12), &numRequests)) response, err := getPaginatedResults[testResultType]( context.Background(), diff --git a/retries.go b/retries.go index 14886442..4c63f998 100644 --- a/retries.go +++ b/retries.go @@ -1,72 +1,94 @@ package linodego import ( + "encoding/json" "errors" "log" "net/http" "strconv" "time" - "github.com/go-resty/resty/v2" "golang.org/x/net/http2" ) const ( - retryAfterHeaderName = "Retry-After" - maintenanceModeHeaderName = "X-Maintenance-Mode" + // nolint:unused + RetryAfterHeaderName = "Retry-After" + // nolint:unused + MaintenanceModeHeaderName = "X-Maintenance-Mode" - defaultRetryCount = 1000 + // nolint:unused + DefaultRetryCount = 1000 ) -// type RetryConditional func(r *resty.Response) (shouldRetry bool) -type RetryConditional resty.RetryConditionFunc +// RetryConditional is a type alias for a function that determines if a request should be retried based on the response and error. +// nolint:unused +type RetryConditional func(*http.Response, error) bool -// type RetryAfter func(c *resty.Client, r *resty.Response) (time.Duration, error) -type RetryAfter resty.RetryAfterFunc +// RetryAfter is a type alias for a function that determines the duration to wait before retrying based on the response. +// nolint:unused +type RetryAfter func(*http.Response) (time.Duration, error) -// Configures resty to -// lock until enough time has passed to retry the request as determined by the Retry-After response header. -// If the Retry-After header is not set, we fall back to value of SetPollDelay. -func configureRetries(c *Client) { - c.resty. - SetRetryCount(defaultRetryCount). - AddRetryCondition(checkRetryConditionals(c)). - SetRetryAfter(respectRetryAfter) +// Configures http.Client to lock until enough time has passed to retry the request as determined by the Retry-After response header. +// If the Retry-After header is not set, we fall back to the value of SetPollDelay. +// nolint:unused +func ConfigureRetries(c *Client) { + c.SetRetryAfter(RespectRetryAfter) + c.SetRetryCount(DefaultRetryCount) } -func checkRetryConditionals(c *Client) func(*resty.Response, error) bool { - return func(r *resty.Response, err error) bool { - for _, retryConditional := range c.retryConditionals { - retry := retryConditional(r, err) - if retry { - log.Printf("[INFO] Received error %s - Retrying", r.Error()) - return true - } - } - return false +// nolint:unused +func RespectRetryAfter(resp *http.Response) (time.Duration, error) { + retryAfterStr := resp.Header.Get(RetryAfterHeaderName) + if retryAfterStr == "" { + return 0, nil } + + retryAfter, err := strconv.Atoi(retryAfterStr) + if err != nil { + return 0, err + } + + duration := time.Duration(retryAfter) * time.Second + log.Printf("[INFO] Respecting Retry-After Header of %d (%s)", retryAfter, duration) + return duration, nil } -// SetLinodeBusyRetry configures resty to retry specifically on "Linode busy." errors -// The retry wait time is configured in SetPollDelay -func linodeBusyRetryCondition(r *resty.Response, _ error) bool { - apiError, ok := r.Error().(*APIError) +// Retry conditions + +// nolint:unused +func LinodeBusyRetryCondition(resp *http.Response, _ error) bool { + if resp == nil { + return false + } + + apiError, ok := getAPIError(resp) linodeBusy := ok && apiError.Error() == "Linode busy." - retry := r.StatusCode() == http.StatusBadRequest && linodeBusy + retry := resp.StatusCode == http.StatusBadRequest && linodeBusy return retry } -func tooManyRequestsRetryCondition(r *resty.Response, _ error) bool { - return r.StatusCode() == http.StatusTooManyRequests +// nolint:unused +func TooManyRequestsRetryCondition(resp *http.Response, _ error) bool { + if resp == nil { + return false + } + + return resp.StatusCode == http.StatusTooManyRequests } -func serviceUnavailableRetryCondition(r *resty.Response, _ error) bool { - serviceUnavailable := r.StatusCode() == http.StatusServiceUnavailable +// nolint:unused +func ServiceUnavailableRetryCondition(resp *http.Response, _ error) bool { + if resp == nil { + return false + } + + serviceUnavailable := resp.StatusCode == http.StatusServiceUnavailable // During maintenance events, the API will return a 503 and add // an `X-MAINTENANCE-MODE` header. Don't retry during maintenance // events, only for legitimate 503s. - if serviceUnavailable && r.Header().Get(maintenanceModeHeaderName) != "" { + if serviceUnavailable && resp.Header.Get(MaintenanceModeHeaderName) != "" { log.Printf("[INFO] Linode API is under maintenance, request will not be retried - please see status.linode.com for more information") return false } @@ -74,32 +96,38 @@ func serviceUnavailableRetryCondition(r *resty.Response, _ error) bool { return serviceUnavailable } -func requestTimeoutRetryCondition(r *resty.Response, _ error) bool { - return r.StatusCode() == http.StatusRequestTimeout -} +// nolint:unused +func RequestTimeoutRetryCondition(resp *http.Response, _ error) bool { + if resp == nil { + return false + } -func requestGOAWAYRetryCondition(_ *resty.Response, e error) bool { - return errors.As(e, &http2.GoAwayError{}) + return resp.StatusCode == http.StatusRequestTimeout } -func requestNGINXRetryCondition(r *resty.Response, _ error) bool { - return r.StatusCode() == http.StatusBadRequest && - r.Header().Get("Server") == "nginx" && - r.Header().Get("Content-Type") == "text/html" +// nolint:unused +func RequestGOAWAYRetryCondition(_ *http.Response, err error) bool { + return errors.As(err, &http2.GoAwayError{}) } -func respectRetryAfter(client *resty.Client, resp *resty.Response) (time.Duration, error) { - retryAfterStr := resp.Header().Get(retryAfterHeaderName) - if retryAfterStr == "" { - return 0, nil +// nolint:unused +func RequestNGINXRetryCondition(resp *http.Response, _ error) bool { + if resp == nil { + return false } - retryAfter, err := strconv.Atoi(retryAfterStr) + return resp.StatusCode == http.StatusBadRequest && + resp.Header.Get("Server") == "nginx" && + resp.Header.Get("Content-Type") == "text/html" +} + +// Helper function to extract APIError from response +// nolint:unused +func getAPIError(resp *http.Response) (*APIError, bool) { + var apiError APIError + err := json.NewDecoder(resp.Body).Decode(&apiError) if err != nil { - return 0, err + return nil, false } - - duration := time.Duration(retryAfter) * time.Second - log.Printf("[INFO] Respecting Retry-After Header of %d (%s) (max %s)", retryAfter, duration, client.RetryMaxWaitTime) - return duration, nil + return &apiError, true } diff --git a/retries_http.go b/retries_http.go deleted file mode 100644 index 0439af48..00000000 --- a/retries_http.go +++ /dev/null @@ -1,127 +0,0 @@ -package linodego - -import ( - "encoding/json" - "errors" - "log" - "net/http" - "strconv" - "time" - - "golang.org/x/net/http2" -) - -const ( - // nolint:unused - httpRetryAfterHeaderName = "Retry-After" - // nolint:unused - httpMaintenanceModeHeaderName = "X-Maintenance-Mode" - - // nolint:unused - httpDefaultRetryCount = 1000 -) - -// RetryConditional is a type alias for a function that determines if a request should be retried based on the response and error. -// nolint:unused -type httpRetryConditional func(*http.Response, error) bool - -// RetryAfter is a type alias for a function that determines the duration to wait before retrying based on the response. -// nolint:unused -type httpRetryAfter func(*http.Response) (time.Duration, error) - -// Configures http.Client to lock until enough time has passed to retry the request as determined by the Retry-After response header. -// If the Retry-After header is not set, we fall back to the value of SetPollDelay. -// nolint:unused -func httpConfigureRetries(c *httpClient) { - c.retryConditionals = append(c.retryConditionals, httpcheckRetryConditionals(c)) - c.retryAfter = httpRespectRetryAfter -} - -// nolint:unused -func httpcheckRetryConditionals(c *httpClient) httpRetryConditional { - return func(resp *http.Response, err error) bool { - for _, retryConditional := range c.retryConditionals { - retry := retryConditional(resp, err) - if retry { - log.Printf("[INFO] Received error %v - Retrying", err) - return true - } - } - return false - } -} - -// nolint:unused -func httpRespectRetryAfter(resp *http.Response) (time.Duration, error) { - retryAfterStr := resp.Header.Get(retryAfterHeaderName) - if retryAfterStr == "" { - return 0, nil - } - - retryAfter, err := strconv.Atoi(retryAfterStr) - if err != nil { - return 0, err - } - - duration := time.Duration(retryAfter) * time.Second - log.Printf("[INFO] Respecting Retry-After Header of %d (%s)", retryAfter, duration) - return duration, nil -} - -// Retry conditions - -// nolint:unused -func httpLinodeBusyRetryCondition(resp *http.Response, _ error) bool { - apiError, ok := getAPIError(resp) - linodeBusy := ok && apiError.Error() == "Linode busy." - retry := resp.StatusCode == http.StatusBadRequest && linodeBusy - return retry -} - -// nolint:unused -func httpTooManyRequestsRetryCondition(resp *http.Response, _ error) bool { - return resp.StatusCode == http.StatusTooManyRequests -} - -// nolint:unused -func httpServiceUnavailableRetryCondition(resp *http.Response, _ error) bool { - serviceUnavailable := resp.StatusCode == http.StatusServiceUnavailable - - // During maintenance events, the API will return a 503 and add - // an `X-MAINTENANCE-MODE` header. Don't retry during maintenance - // events, only for legitimate 503s. - if serviceUnavailable && resp.Header.Get(maintenanceModeHeaderName) != "" { - log.Printf("[INFO] Linode API is under maintenance, request will not be retried - please see status.linode.com for more information") - return false - } - - return serviceUnavailable -} - -// nolint:unused -func httpRequestTimeoutRetryCondition(resp *http.Response, _ error) bool { - return resp.StatusCode == http.StatusRequestTimeout -} - -// nolint:unused -func httpRequestGOAWAYRetryCondition(_ *http.Response, err error) bool { - return errors.As(err, &http2.GoAwayError{}) -} - -// nolint:unused -func httpRequestNGINXRetryCondition(resp *http.Response, _ error) bool { - return resp.StatusCode == http.StatusBadRequest && - resp.Header.Get("Server") == "nginx" && - resp.Header.Get("Content-Type") == "text/html" -} - -// Helper function to extract APIError from response -// nolint:unused -func getAPIError(resp *http.Response) (*APIError, bool) { - var apiError APIError - err := json.NewDecoder(resp.Body).Decode(&apiError) - if err != nil { - return nil, false - } - return &apiError, true -} diff --git a/retries_http_test.go b/retries_http_test.go deleted file mode 100644 index 35eea5fc..00000000 --- a/retries_http_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package linodego - -import ( - "bytes" - "encoding/json" - "io" - "net/http" - "testing" - "time" -) - -func TestHTTPLinodeBusyRetryCondition(t *testing.T) { - var retry bool - - // Initialize response body - rawResponse := &http.Response{ - StatusCode: http.StatusBadRequest, - Body: io.NopCloser(bytes.NewBuffer(nil)), - } - - retry = httpLinodeBusyRetryCondition(rawResponse, nil) - - if retry { - t.Errorf("Should not have retried") - } - - apiError := APIError{ - Errors: []APIErrorReason{ - {Reason: "Linode busy."}, - }, - } - rawResponse.Body = createResponseBody(apiError) - - retry = httpLinodeBusyRetryCondition(rawResponse, nil) - - if !retry { - t.Errorf("Should have retried") - } -} - -func TestHTTPServiceUnavailableRetryCondition(t *testing.T) { - rawResponse := &http.Response{ - StatusCode: http.StatusServiceUnavailable, - Header: http.Header{httpRetryAfterHeaderName: []string{"20"}}, - Body: io.NopCloser(bytes.NewBuffer(nil)), // Initialize response body - } - - if retry := httpServiceUnavailableRetryCondition(rawResponse, nil); !retry { - t.Error("expected request to be retried") - } - - if retryAfter, err := httpRespectRetryAfter(rawResponse); err != nil { - t.Errorf("expected error to be nil but got %s", err) - } else if retryAfter != time.Second*20 { - t.Errorf("expected retryAfter to be 20 but got %d", retryAfter) - } -} - -func TestHTTPServiceMaintenanceModeRetryCondition(t *testing.T) { - rawResponse := &http.Response{ - StatusCode: http.StatusServiceUnavailable, - Header: http.Header{ - httpRetryAfterHeaderName: []string{"20"}, - httpMaintenanceModeHeaderName: []string{"Currently in maintenance mode."}, - }, - Body: io.NopCloser(bytes.NewBuffer(nil)), // Initialize response body - } - - if retry := httpServiceUnavailableRetryCondition(rawResponse, nil); retry { - t.Error("expected retry to be skipped due to maintenance mode header") - } -} - -// Helper function to create a response body from an object -func createResponseBody(obj interface{}) io.ReadCloser { - body, err := json.Marshal(obj) - if err != nil { - panic(err) - } - return io.NopCloser(bytes.NewBuffer(body)) -} diff --git a/retries_test.go b/retries_test.go index 4f002938..45b5fc4d 100644 --- a/retries_test.go +++ b/retries_test.go @@ -1,24 +1,24 @@ package linodego import ( + "bytes" + "encoding/json" + "io" "net/http" "testing" "time" - - "github.com/go-resty/resty/v2" ) func TestLinodeBusyRetryCondition(t *testing.T) { var retry bool - request := resty.Request{} - rawResponse := http.Response{StatusCode: http.StatusBadRequest} - response := resty.Response{ - Request: &request, - RawResponse: &rawResponse, + // Initialize response body + rawResponse := &http.Response{ + StatusCode: http.StatusBadRequest, + Body: io.NopCloser(bytes.NewBuffer(nil)), } - retry = linodeBusyRetryCondition(&response, nil) + retry = LinodeBusyRetryCondition(rawResponse, nil) if retry { t.Errorf("Should not have retried") @@ -29,48 +29,53 @@ func TestLinodeBusyRetryCondition(t *testing.T) { {Reason: "Linode busy."}, }, } - request.SetError(&apiError) + rawResponse.Body = createResponseBody(apiError) - retry = linodeBusyRetryCondition(&response, nil) + retry = LinodeBusyRetryCondition(rawResponse, nil) if !retry { t.Errorf("Should have retried") } } -func TestLinodeServiceUnavailableRetryCondition(t *testing.T) { - request := resty.Request{} - rawResponse := http.Response{StatusCode: http.StatusServiceUnavailable, Header: http.Header{ - retryAfterHeaderName: []string{"20"}, - }} - response := resty.Response{ - Request: &request, - RawResponse: &rawResponse, +func TestServiceUnavailableRetryCondition(t *testing.T) { + rawResponse := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{RetryAfterHeaderName: []string{"20"}}, + Body: io.NopCloser(bytes.NewBuffer(nil)), // Initialize response body } - if retry := serviceUnavailableRetryCondition(&response, nil); !retry { + if retry := ServiceUnavailableRetryCondition(rawResponse, nil); !retry { t.Error("expected request to be retried") } - if retryAfter, err := respectRetryAfter(NewClient(nil).resty, &response); err != nil { + if retryAfter, err := RespectRetryAfter(rawResponse); err != nil { t.Errorf("expected error to be nil but got %s", err) } else if retryAfter != time.Second*20 { t.Errorf("expected retryAfter to be 20 but got %d", retryAfter) } } -func TestLinodeServiceMaintenanceModeRetryCondition(t *testing.T) { - request := resty.Request{} - rawResponse := http.Response{StatusCode: http.StatusServiceUnavailable, Header: http.Header{ - retryAfterHeaderName: []string{"20"}, - maintenanceModeHeaderName: []string{"Currently in maintenance mode."}, - }} - response := resty.Response{ - Request: &request, - RawResponse: &rawResponse, +func TestServiceMaintenanceModeRetryCondition(t *testing.T) { + rawResponse := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{ + RetryAfterHeaderName: []string{"20"}, + MaintenanceModeHeaderName: []string{"Currently in maintenance mode."}, + }, + Body: io.NopCloser(bytes.NewBuffer(nil)), // Initialize response body } - if retry := serviceUnavailableRetryCondition(&response, nil); retry { + if retry := ServiceUnavailableRetryCondition(rawResponse, nil); retry { t.Error("expected retry to be skipped due to maintenance mode header") } } + +// Helper function to create a response body from an object +func createResponseBody(obj interface{}) io.ReadCloser { + body, err := json.Marshal(obj) + if err != nil { + panic(err) + } + return io.NopCloser(bytes.NewBuffer(body)) +}