Skip to content

Commit

Permalink
Support http digest authentication.
Browse files Browse the repository at this point in the history
* Add SetDigestAuth for Request.
* Add SetCommonDigestAuth for Client.
* Add OnAfterResponse for Request
  • Loading branch information
imroc committed Jun 17, 2023
1 parent 767a6b9 commit ef6c1ad
Show file tree
Hide file tree
Showing 6 changed files with 363 additions and 39 deletions.
63 changes: 45 additions & 18 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -767,57 +767,69 @@ func (c *Client) EnableAutoReadResponse() *Client {
return c
}

// SetAutoDecodeContentType set the content types that will be auto-detected and decode
// to utf-8 (e.g. "json", "xml", "html", "text").
// SetAutoDecodeContentType set the content types that will be auto-detected and decode to utf-8
// (e.g. "json", "xml", "html", "text").
func (c *Client) SetAutoDecodeContentType(contentTypes ...string) *Client {
c.t.SetAutoDecodeContentType(contentTypes...)
return c
}

// SetAutoDecodeContentTypeFunc set the function that determines whether the
// specified `Content-Type` should be auto-detected and decode to utf-8.
// SetAutoDecodeContentTypeFunc set the function that determines whether the specified `Content-Type` should be auto-detected and decode to utf-8.
func (c *Client) SetAutoDecodeContentTypeFunc(fn func(contentType string) bool) *Client {
c.t.SetAutoDecodeContentTypeFunc(fn)
return c
}

// SetAutoDecodeAllContentType enable try auto-detect charset and decode all
// content type to utf-8.
// SetAutoDecodeAllContentType enable try auto-detect charset and decode all content type to utf-8.
func (c *Client) SetAutoDecodeAllContentType() *Client {
c.t.SetAutoDecodeAllContentType()
return c
}

// DisableAutoDecode disable auto-detect charset and decode to utf-8
// (enabled by default).
// DisableAutoDecode disable auto-detect charset and decode to utf-8 (enabled by default).
func (c *Client) DisableAutoDecode() *Client {
c.t.DisableAutoDecode()
return c
}

// EnableAutoDecode enable auto-detect charset and decode to utf-8
// (enabled by default).
// EnableAutoDecode enable auto-detect charset and decode to utf-8 (enabled by default).
func (c *Client) EnableAutoDecode() *Client {
c.t.EnableAutoDecode()
return c
}

// SetUserAgent set the "User-Agent" header for requests fired from
// the client.
// SetUserAgent set the "User-Agent" header for requests fired from the client.
func (c *Client) SetUserAgent(userAgent string) *Client {
return c.SetCommonHeader(header.UserAgent, userAgent)
}

// SetCommonBearerAuthToken set the bearer auth token for requests
// fired from the client.
// SetCommonBearerAuthToken set the bearer auth token for requests fired from the client.
func (c *Client) SetCommonBearerAuthToken(token string) *Client {
return c.SetCommonHeader("Authorization", "Bearer "+token)
return c.SetCommonHeader(header.Authorization, "Bearer "+token)
}

// SetCommonBasicAuth set the basic auth for requests fired from
// the client.
func (c *Client) SetCommonBasicAuth(username, password string) *Client {
c.SetCommonHeader("Authorization", util.BasicAuthHeaderValue(username, password))
c.SetCommonHeader(header.Authorization, util.BasicAuthHeaderValue(username, password))
return c
}

// SetCommonDigestAuth sets the Digest Access auth scheme for requests fired from the client. If a server responds with
// 401 and sends a Digest challenge in the WWW-Authenticate Header, requests will be resent with the appropriate
// Authorization Header.
//
// For Example: To set the Digest scheme with user "roc" and password "123456"
//
// client.SetCommonDigestAuth("roc", "123456")
//
// Information about Digest Access Authentication can be found in RFC7616:
//
// https://datatracker.ietf.org/doc/html/rfc7616
//
// See `Request.SetDigestAuth`
func (c *Client) SetCommonDigestAuth(username, password string) *Client {
c.OnAfterResponse(handleDigestAuthFunc(username, password))
return c
}

Expand Down Expand Up @@ -1500,6 +1512,21 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) {
err = resp.Err
}()

if r.Headers == nil {
r.Headers = make(http.Header)
}

for _, f := range r.client.udBeforeRequest {
if err = f(r.client, r); err != nil {
return
}
}
for _, f := range r.client.beforeRequest {
if err = f(r.client, r); err != nil {
return
}
}

// setup trace
if r.trace == nil && r.client.trace {
r.trace = &clientTrace{}
Expand Down Expand Up @@ -1581,8 +1608,8 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) {
resp.Body = io.NopCloser(bytes.NewReader(resp.body))
}

for _, f := range r.client.afterResponse {
if e := f(r.client, resp); e != nil {
for _, f := range c.afterResponse {
if e := f(c, resp); e != nil {
resp.Err = e
}
}
Expand Down
277 changes: 277 additions & 0 deletions digest.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
package req

import (
"crypto/md5"
"crypto/rand"
"crypto/sha256"
"crypto/sha512"
"errors"
"fmt"
"github.com/imroc/req/v3/internal/header"
"hash"
"io"
"net/http"
"strings"
)

var (
errDigestBadChallenge = errors.New("digest: challenge is bad")
errDigestCharset = errors.New("digest: unsupported charset")
errDigestAlgNotSupported = errors.New("digest: algorithm is not supported")
errDigestQopNotSupported = errors.New("digest: no supported qop in list")
errDigestNoQop = errors.New("digest: qop must be specified")
)

var hashFuncs = map[string]func() hash.Hash{
"": md5.New,
"MD5": md5.New,
"MD5-sess": md5.New,
"SHA-256": sha256.New,
"SHA-256-sess": sha256.New,
"SHA-512-256": sha512.New,
"SHA-512-256-sess": sha512.New,
}

// create response middleware for http digest authentication.
func handleDigestAuthFunc(username, password string) ResponseMiddleware {
return func(client *Client, resp *Response) error {
if resp.Err != nil || resp.StatusCode != http.StatusUnauthorized {
return nil
}
auth, err := createDigestAuth(resp.Response, username, password)
if err != nil {
return err
}
r := resp.Request
req := *r.RawRequest
if req.Body != nil {
err = parseRequestBody(client, r) // re-setup body
if err != nil {
return err
}
if r.GetBody != nil {
body, err := r.GetBody()
if err != nil {
return err
}
req.Body = body
req.GetBody = r.GetBody
}
}
if req.Header == nil {
req.Header = make(http.Header)
}
req.Header.Set(header.Authorization, auth)
resp.Response, err = client.GetTransport().RoundTrip(&req)
return err
}
}

func createDigestAuth(resp *http.Response, username, password string) (auth string, err error) {
chal := resp.Header.Get(header.WwwAuthenticate)
if chal == "" {
return "", errDigestBadChallenge
}

c, err := parseChallenge(chal)
if err != nil {
return "", err
}

// Form credentials based on the challenge
cr := newCredentials(resp.Request.URL.RequestURI(), resp.Request.Method, username, password, c)
auth, err = cr.authorize()
return
}

func newCredentials(digestURI, method, username, password string, c *challenge) *credentials {
return &credentials{
username: username,
userhash: c.userhash,
realm: c.realm,
nonce: c.nonce,
digestURI: digestURI,
algorithm: c.algorithm,
sessionAlg: strings.HasSuffix(c.algorithm, "-sess"),
opaque: c.opaque,
messageQop: c.qop,
nc: 0,
method: method,
password: password,
}
}

type challenge struct {
realm string
domain string
nonce string
opaque string
stale string
algorithm string
qop string
userhash string
}

func parseChallenge(input string) (*challenge, error) {
const ws = " \n\r\t"
const qs = `"`
s := strings.Trim(input, ws)
if !strings.HasPrefix(s, "Digest ") {
return nil, errDigestBadChallenge
}
s = strings.Trim(s[7:], ws)
sl := strings.Split(s, ", ")
c := &challenge{}
var r []string
for i := range sl {
r = strings.SplitN(sl[i], "=", 2)
if len(r) != 2 {
return nil, errDigestBadChallenge
}
switch r[0] {
case "realm":
c.realm = strings.Trim(r[1], qs)
case "domain":
c.domain = strings.Trim(r[1], qs)
case "nonce":
c.nonce = strings.Trim(r[1], qs)
case "opaque":
c.opaque = strings.Trim(r[1], qs)
case "stale":
c.stale = r[1]
case "algorithm":
c.algorithm = r[1]
case "qop":
c.qop = strings.Trim(r[1], qs)
case "charset":
if strings.ToUpper(strings.Trim(r[1], qs)) != "UTF-8" {
return nil, errDigestCharset
}
case "userhash":
c.userhash = strings.Trim(r[1], qs)
default:
return nil, errDigestBadChallenge
}
}
return c, nil
}

type credentials struct {
username string
userhash string
realm string
nonce string
digestURI string
algorithm string
sessionAlg bool
cNonce string
opaque string
messageQop string
nc int
method string
password string
}

func (c *credentials) authorize() (string, error) {
if _, ok := hashFuncs[c.algorithm]; !ok {
return "", errDigestAlgNotSupported
}

if err := c.validateQop(); err != nil {
return "", err
}

resp, err := c.resp()
if err != nil {
return "", err
}

sl := make([]string, 0, 10)
if c.userhash == "true" {
// RFC 7616 3.4.4
c.username = c.h(fmt.Sprintf("%s:%s", c.username, c.realm))
sl = append(sl, fmt.Sprintf(`userhash=%s`, c.userhash))
}
sl = append(sl, fmt.Sprintf(`username="%s"`, c.username))
sl = append(sl, fmt.Sprintf(`realm="%s"`, c.realm))
sl = append(sl, fmt.Sprintf(`nonce="%s"`, c.nonce))
sl = append(sl, fmt.Sprintf(`uri="%s"`, c.digestURI))
sl = append(sl, fmt.Sprintf(`response="%s"`, resp))
sl = append(sl, fmt.Sprintf(`algorithm=%s`, c.algorithm))
if c.opaque != "" {
sl = append(sl, fmt.Sprintf(`opaque="%s"`, c.opaque))
}
if c.messageQop != "" {
sl = append(sl, fmt.Sprintf("qop=%s", c.messageQop))
sl = append(sl, fmt.Sprintf("nc=%08x", c.nc))
sl = append(sl, fmt.Sprintf(`cnonce="%s"`, c.cNonce))
}

return fmt.Sprintf("Digest %s", strings.Join(sl, ", ")), nil
}

func (c *credentials) validateQop() error {
// Currently only supporting auth quality of protection. TODO: add auth-int support
if c.messageQop == "" {
return errDigestNoQop
}
possibleQops := strings.Split(c.messageQop, ", ")
var authSupport bool
for _, qop := range possibleQops {
if qop == "auth" {
authSupport = true
break
}
}
if !authSupport {
return errDigestQopNotSupported
}

c.messageQop = "auth"

return nil
}

func (c *credentials) h(data string) string {
hfCtor := hashFuncs[c.algorithm]
hf := hfCtor()
_, _ = hf.Write([]byte(data)) // Hash.Write never returns an error
return fmt.Sprintf("%x", hf.Sum(nil))
}

func (c *credentials) resp() (string, error) {
c.nc++

b := make([]byte, 16)
_, err := io.ReadFull(rand.Reader, b)
if err != nil {
return "", err
}
c.cNonce = fmt.Sprintf("%x", b)[:32]

ha1 := c.ha1()
ha2 := c.ha2()

return c.kd(ha1, fmt.Sprintf("%s:%08x:%s:%s:%s",
c.nonce, c.nc, c.cNonce, c.messageQop, ha2)), nil
}

func (c *credentials) kd(secret, data string) string {
return c.h(fmt.Sprintf("%s:%s", secret, data))
}

// RFC 7616 3.4.2
func (c *credentials) ha1() string {
ret := c.h(fmt.Sprintf("%s:%s:%s", c.username, c.realm, c.password))
if c.sessionAlg {
return c.h(fmt.Sprintf("%s:%s:%s", ret, c.nonce, c.cNonce))
}

return ret
}

// RFC 7616 3.4.3
func (c *credentials) ha2() string {
// currently no auth-int support
return c.h(fmt.Sprintf("%s:%s", c.method, c.digestURI))
}
Loading

0 comments on commit ef6c1ad

Please sign in to comment.