diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index b5ee7b37..f1906baf 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -40,7 +40,7 @@ var counter atomic.Int64 // SetupServer creates an OIDC server with Issuer=http://localhost: // // Use one of the pre-made clients in storage/clients.go or register a new one. -func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router { +func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer bool) chi.Router { // the OpenID Provider requires a 32-byte key for (token) encryption // be sure to create a proper crypto random key and manage it securely! key := sha256.Sum256([]byte("test")) @@ -77,12 +77,17 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router registerDeviceAuth(storage, r) }) + handler := http.Handler(provider) + if wrapServer { + handler = op.NewLegacyServer(provider, *op.DefaultEndpoints) + } + // we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) // is served on the correct path // // if your issuer ends with a path (e.g. http://localhost:9998/custom/path/), // then you would have to set the path prefix (/custom/path/) - router.Mount("/", provider) + router.Mount("/", handler) return router } diff --git a/example/server/main.go b/example/server/main.go index a1cc4618..38057fb7 100644 --- a/example/server/main.go +++ b/example/server/main.go @@ -27,7 +27,7 @@ func main() { Level: slog.LevelDebug, }), ) - router := exampleop.SetupServer(issuer, storage, logger) + router := exampleop.SetupServer(issuer, storage, logger, false) server := &http.Server{ Addr: ":" + port, diff --git a/example/server/storage/client.go b/example/server/storage/client.go index a3e7cc45..f512a992 100644 --- a/example/server/storage/client.go +++ b/example/server/storage/client.go @@ -185,7 +185,7 @@ func WebClient(id, secret string, redirectURIs ...string) *Client { authMethod: oidc.AuthMethodBasic, loginURL: defaultLoginURL, responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode}, - grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken}, + grantTypes: oidc.AllGrantTypes, accessTokenType: op.AccessTokenTypeBearer, devMode: false, idTokenUserinfoClaimsAssertion: false, diff --git a/pkg/client/integration_test.go b/pkg/client/integration_test.go index 7cbb62e6..1d3559e3 100644 --- a/pkg/client/integration_test.go +++ b/pkg/client/integration_test.go @@ -3,6 +3,7 @@ package client_test import ( "bytes" "context" + "fmt" "io" "math/rand" "net/http" @@ -50,6 +51,14 @@ func TestMain(m *testing.M) { } func TestRelyingPartySession(t *testing.T) { + for _, wrapServer := range []bool{false, true} { + t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) { + testRelyingPartySession(t, wrapServer) + }) + } +} + +func testRelyingPartySession(t *testing.T, wrapServer bool) { t.Log("------- start example OP ------") targetURL := "http://local-site" exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) @@ -57,7 +66,7 @@ func TestRelyingPartySession(t *testing.T) { opServer := httptest.NewServer(&dh) defer opServer.Close() t.Logf("auth server at %s", opServer.URL) - dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger) + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) @@ -101,6 +110,14 @@ func TestRelyingPartySession(t *testing.T) { } func TestResourceServerTokenExchange(t *testing.T) { + for _, wrapServer := range []bool{false, true} { + t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) { + testResourceServerTokenExchange(t, wrapServer) + }) + } +} + +func testResourceServerTokenExchange(t *testing.T, wrapServer bool) { t.Log("------- start example OP ------") targetURL := "http://local-site" exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) @@ -108,7 +125,7 @@ func TestResourceServerTokenExchange(t *testing.T) { opServer := httptest.NewServer(&dh) defer opServer.Close() t.Logf("auth server at %s", opServer.URL) - dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger) + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index 7610248e..20b1bf4c 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -74,7 +74,7 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { } ctx := r.Context() if authReq.RequestParam != "" && authorizer.RequestObjectSupported() { - authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx)) + err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx)) if err != nil { AuthRequestError(w, r, authReq, err, authorizer) return @@ -130,31 +130,31 @@ func ParseAuthorizeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.A // ParseRequestObject parse the `request` parameter, validates the token including the signature // and copies the token claims into the auth request -func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) (*oidc.AuthRequest, error) { +func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) error { requestObject := new(oidc.RequestObject) payload, err := oidc.ParseToken(authReq.RequestParam, requestObject) if err != nil { - return nil, err + return err } if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID { - return authReq, oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest() } if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType { - return authReq, oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest() } if requestObject.Issuer != requestObject.ClientID { - return authReq, oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest() } if !str.Contains(requestObject.Audience, issuer) { - return authReq, oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest() } keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer} if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil { - return authReq, err + return err } CopyRequestObjectToAuthRequest(authReq, requestObject) - return authReq, nil + return nil } // CopyRequestObjectToAuthRequest overwrites present values from the Request Object into the auth request diff --git a/pkg/op/client.go b/pkg/op/client.go index d01845f2..04ef3c71 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -180,3 +180,10 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au } return data.ClientID, false, nil } + +type ClientCredentials struct { + ClientID string `schema:"client_id"` + ClientSecret string `schema:"client_secret"` // Client secret from Basic auth or request body + ClientAssertion string `schema:"client_assertion"` // JWT + ClientAssertionType string `schema:"client_assertion_type"` +} diff --git a/pkg/op/config.go b/pkg/op/config.go index c40ed39e..f61412a8 100644 --- a/pkg/op/config.go +++ b/pkg/op/config.go @@ -20,14 +20,14 @@ var ( type Configuration interface { IssuerFromRequest(r *http.Request) string Insecure() bool - AuthorizationEndpoint() Endpoint - TokenEndpoint() Endpoint - IntrospectionEndpoint() Endpoint - UserinfoEndpoint() Endpoint - RevocationEndpoint() Endpoint - EndSessionEndpoint() Endpoint - KeysEndpoint() Endpoint - DeviceAuthorizationEndpoint() Endpoint + AuthorizationEndpoint() *Endpoint + TokenEndpoint() *Endpoint + IntrospectionEndpoint() *Endpoint + UserinfoEndpoint() *Endpoint + RevocationEndpoint() *Endpoint + EndSessionEndpoint() *Endpoint + KeysEndpoint() *Endpoint + DeviceAuthorizationEndpoint() *Endpoint AuthMethodPostSupported() bool CodeMethodS256Supported() bool diff --git a/pkg/op/device.go b/pkg/op/device.go index 029bed8a..55d3c572 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -63,41 +63,51 @@ func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *htt } func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error { - storage, err := assertDeviceStorage(o.Storage()) + req, err := ParseDeviceCodeRequest(r, o) if err != nil { return err } - - req, err := ParseDeviceCodeRequest(r, o) + response, err := createDeviceAuthorization(r.Context(), req, req.ClientID, o) if err != nil { return err } + httphelper.MarshalJSON(w, response) + return nil +} + +func createDeviceAuthorization(ctx context.Context, req *oidc.DeviceAuthorizationRequest, clientID string, o OpenIDProvider) (*oidc.DeviceAuthorizationResponse, error) { + storage, err := assertDeviceStorage(o.Storage()) + if err != nil { + return nil, err + } config := o.DeviceAuthorization() deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes) if err != nil { - return err + return nil, NewStatusError(err, http.StatusInternalServerError) } userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval) if err != nil { - return err + return nil, NewStatusError(err, http.StatusInternalServerError) } expires := time.Now().Add(config.Lifetime) - err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes) + err = storage.StoreDeviceAuthorization(ctx, clientID, deviceCode, userCode, expires, req.Scopes) if err != nil { - return err + return nil, NewStatusError(err, http.StatusInternalServerError) } var verification *url.URL if config.UserFormURL != "" { if verification, err = url.Parse(config.UserFormURL); err != nil { - return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form") + err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form") + return nil, NewStatusError(err, http.StatusInternalServerError) } } else { - if verification, err = url.Parse(IssuerFromContext(r.Context())); err != nil { - return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer") + if verification, err = url.Parse(IssuerFromContext(ctx)); err != nil { + err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer") + return nil, NewStatusError(err, http.StatusInternalServerError) } verification.Path = config.UserFormPath } @@ -112,9 +122,7 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide verification.RawQuery = "user_code=" + userCode response.VerificationURIComplete = verification.String() - - httphelper.MarshalJSON(w, response) - return nil + return response, nil } func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) { diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index 782a279e..82512615 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -25,7 +25,7 @@ var DefaultSupportedScopes = []string{ func discoveryHandler(c Configuration, s DiscoverStorage) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - Discover(w, CreateDiscoveryConfig(r, c, s)) + Discover(w, CreateDiscoveryConfig(r.Context(), c, s)) } } @@ -33,8 +33,8 @@ func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) { httphelper.MarshalJSON(w, config) } -func CreateDiscoveryConfig(r *http.Request, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration { - issuer := config.IssuerFromRequest(r) +func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration { + issuer := IssuerFromContext(ctx) return &oidc.DiscoveryConfiguration{ Issuer: issuer, AuthorizationEndpoint: config.AuthorizationEndpoint().Absolute(issuer), @@ -49,7 +49,38 @@ func CreateDiscoveryConfig(r *http.Request, config Configuration, storage Discov ResponseTypesSupported: ResponseTypes(config), GrantTypesSupported: GrantTypes(config), SubjectTypesSupported: SubjectTypes(config), - IDTokenSigningAlgValuesSupported: SigAlgorithms(r.Context(), storage), + IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage), + RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config), + TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config), + TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config), + IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(config), + IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(config), + RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(config), + RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(config), + ClaimsSupported: SupportedClaims(config), + CodeChallengeMethodsSupported: CodeChallengeMethods(config), + UILocalesSupported: config.SupportedUILocales(), + RequestParameterSupported: config.RequestObjectSupported(), + } +} + +func createDiscoveryConfigV2(ctx context.Context, config Configuration, storage DiscoverStorage, endpoints *Endpoints) *oidc.DiscoveryConfiguration { + issuer := IssuerFromContext(ctx) + return &oidc.DiscoveryConfiguration{ + Issuer: issuer, + AuthorizationEndpoint: endpoints.Authorization.Absolute(issuer), + TokenEndpoint: endpoints.Token.Absolute(issuer), + IntrospectionEndpoint: endpoints.Introspection.Absolute(issuer), + UserinfoEndpoint: endpoints.Userinfo.Absolute(issuer), + RevocationEndpoint: endpoints.Revocation.Absolute(issuer), + EndSessionEndpoint: endpoints.EndSession.Absolute(issuer), + JwksURI: endpoints.JwksURI.Absolute(issuer), + DeviceAuthorizationEndpoint: endpoints.DeviceAuthorization.Absolute(issuer), + ScopesSupported: Scopes(config), + ResponseTypesSupported: ResponseTypes(config), + GrantTypesSupported: GrantTypes(config), + SubjectTypesSupported: SubjectTypes(config), + IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage), RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config), TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config), TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config), diff --git a/pkg/op/discovery_test.go b/pkg/op/discovery_test.go index 3e95ec35..84e12165 100644 --- a/pkg/op/discovery_test.go +++ b/pkg/op/discovery_test.go @@ -48,9 +48,9 @@ func TestDiscover(t *testing.T) { func TestCreateDiscoveryConfig(t *testing.T) { type args struct { - request *http.Request - c op.Configuration - s op.DiscoverStorage + ctx context.Context + c op.Configuration + s op.DiscoverStorage } tests := []struct { name string @@ -61,7 +61,7 @@ func TestCreateDiscoveryConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := op.CreateDiscoveryConfig(tt.args.request, tt.args.c, tt.args.s) + got := op.CreateDiscoveryConfig(tt.args.ctx, tt.args.c, tt.args.s) assert.Equal(t, tt.want, got) }) } diff --git a/pkg/op/endpoint.go b/pkg/op/endpoint.go index b1e15073..1ac1cad7 100644 --- a/pkg/op/endpoint.go +++ b/pkg/op/endpoint.go @@ -1,32 +1,46 @@ package op -import "strings" +import ( + "errors" + "strings" +) type Endpoint struct { path string url string } -func NewEndpoint(path string) Endpoint { - return Endpoint{path: path} +func NewEndpoint(path string) *Endpoint { + return &Endpoint{path: path} } -func NewEndpointWithURL(path, url string) Endpoint { - return Endpoint{path: path, url: url} +func NewEndpointWithURL(path, url string) *Endpoint { + return &Endpoint{path: path, url: url} } -func (e Endpoint) Relative() string { +func (e *Endpoint) Relative() string { + if e == nil { + return "" + } return relativeEndpoint(e.path) } -func (e Endpoint) Absolute(host string) string { +func (e *Endpoint) Absolute(host string) string { + if e == nil { + return "" + } if e.url != "" { return e.url } return absoluteEndpoint(host, e.path) } -func (e Endpoint) Validate() error { +var ErrNilEndpoint = errors.New("nil endpoint") + +func (e *Endpoint) Validate() error { + if e == nil { + return ErrNilEndpoint + } return nil // TODO: } diff --git a/pkg/op/endpoint_test.go b/pkg/op/endpoint_test.go index 46e5d478..bf112eff 100644 --- a/pkg/op/endpoint_test.go +++ b/pkg/op/endpoint_test.go @@ -3,13 +3,14 @@ package op_test import ( "testing" + "github.com/stretchr/testify/require" "github.com/zitadel/oidc/v3/pkg/op" ) func TestEndpoint_Path(t *testing.T) { tests := []struct { name string - e op.Endpoint + e *op.Endpoint want string }{ { @@ -27,6 +28,11 @@ func TestEndpoint_Path(t *testing.T) { op.NewEndpointWithURL("/test", "http://test.com/test"), "/test", }, + { + "nil", + nil, + "", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -43,7 +49,7 @@ func TestEndpoint_Absolute(t *testing.T) { } tests := []struct { name string - e op.Endpoint + e *op.Endpoint args args want string }{ @@ -77,6 +83,12 @@ func TestEndpoint_Absolute(t *testing.T) { args{"https://host"}, "https://test.com/test", }, + { + "nil", + nil, + args{"https://host"}, + "", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -91,16 +103,19 @@ func TestEndpoint_Absolute(t *testing.T) { func TestEndpoint_Validate(t *testing.T) { tests := []struct { name string - e op.Endpoint - wantErr bool + e *op.Endpoint + wantErr error }{ - // TODO: Add test cases. + { + "nil", + nil, + op.ErrNilEndpoint, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tt.e.Validate(); (err != nil) != tt.wantErr { - t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr) - } + err := tt.e.Validate() + require.ErrorIs(t, err, tt.wantErr) }) } } diff --git a/pkg/op/error.go b/pkg/op/error.go index 9981fecc..0cac14b9 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -1,6 +1,9 @@ package op import ( + "context" + "errors" + "fmt" "net/http" httphelper "github.com/zitadel/oidc/v3/pkg/http" @@ -66,3 +69,101 @@ func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slo logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e) httphelper.MarshalJSONWithStatus(w, e, status) } + +// TryErrorRedirect tries to handle an error by redirecting a client. +// If this attempt fails, an error is returned that must be returned +// to the client instead. +func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, encoder httphelper.Encoder, logger *slog.Logger) (*Redirect, error) { + e := oidc.DefaultToServerError(parent, parent.Error()) + logger = logger.With("oidc_error", e) + + if authReq == nil { + logger.Log(ctx, e.LogLevel(), "auth request") + return nil, AsStatusError(e, http.StatusBadRequest) + } + + if logAuthReq, ok := authReq.(LogAuthRequest); ok { + logger = logger.With("auth_request", logAuthReq) + } + + if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() { + logger.Log(ctx, e.LogLevel(), "auth request: not redirecting") + return nil, AsStatusError(e, http.StatusBadRequest) + } + + e.State = authReq.GetState() + var responseMode oidc.ResponseMode + if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok { + responseMode = rm.GetResponseMode() + } + url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder) + if err != nil { + logger.ErrorContext(ctx, "auth response URL", "error", err) + return nil, AsStatusError(err, http.StatusBadRequest) + } + logger.Log(ctx, e.LogLevel(), "auth request redirect", "url", url) + return NewRedirect(url), nil +} + +// StatusError wraps an error with a HTTP status code. +// The status code is passed to the handler's writer. +type StatusError struct { + parent error + statusCode int +} + +// NewStatusError sets the parent and statusCode to a new StatusError. +// It is recommended for parent to be an [oidc.Error]. +// +// Typically implementations should only use this to signal something +// very specific, like an internal server error. +// If a returned error is not a StatusError, the framework +// will set a statusCode based on what the standard specifies, +// which is [http.StatusBadRequest] for most of the time. +// If the error encountered can described clearly with a [oidc.Error], +// do not use this function, as it might break standard rules! +func NewStatusError(parent error, statusCode int) StatusError { + return StatusError{ + parent: parent, + statusCode: statusCode, + } +} + +// AsStatusError unwraps a StatusError from err +// and returns it unmodified if found. +// If no StatuError was found, a new one is returned +// with statusCode set to it as a default. +func AsStatusError(err error, statusCode int) (target StatusError) { + if errors.As(err, &target) { + return target + } + return NewStatusError(err, statusCode) +} + +func (e StatusError) Error() string { + return fmt.Sprintf("%s: %s", http.StatusText(e.statusCode), e.parent.Error()) +} + +func (e StatusError) Unwrap() error { + return e.parent +} + +func (e StatusError) Is(err error) bool { + var target StatusError + if !errors.As(err, &target) { + return false + } + return errors.Is(e.parent, target.parent) && + e.statusCode == target.statusCode +} + +// WriteError asserts for a StatusError containing an [oidc.Error]. +// If no StatusError is found, the status code will default to [http.StatusBadRequest]. +// If no [oidc.Error] was found in the parent, the error type defaults to [oidc.ServerError]. +func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) { + statusError := AsStatusError(err, http.StatusBadRequest) + e := oidc.DefaultToServerError(statusError.parent, statusError.parent.Error()) + + logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e) + httphelper.MarshalJSONWithStatus(w, e, statusError.statusCode) +} diff --git a/pkg/op/error_test.go b/pkg/op/error_test.go index dc5ef110..689ee5ab 100644 --- a/pkg/op/error_test.go +++ b/pkg/op/error_test.go @@ -1,9 +1,12 @@ package op import ( + "context" + "fmt" "io" "net/http" "net/http/httptest" + "net/url" "strings" "testing" @@ -275,3 +278,400 @@ func TestRequestError(t *testing.T) { }) } } + +func TestTryErrorRedirect(t *testing.T) { + type args struct { + ctx context.Context + authReq ErrAuthRequest + parent error + } + tests := []struct { + name string + args args + want *Redirect + wantErr error + wantLog string + }{ + { + name: "nil auth request", + args: args{ + ctx: context.Background(), + authReq: nil, + parent: io.ErrClosedPipe, + }, + wantErr: NewStatusError(io.ErrClosedPipe, http.StatusBadRequest), + wantLog: `{ + "level":"ERROR", + "msg":"auth request", + "time":"not", + "oidc_error":{ + "description":"io: read/write on closed pipe", + "parent":"io: read/write on closed pipe", + "type":"server_error" + } + }`, + }, + { + name: "auth request, no redirect URI", + args: args{ + ctx: context.Background(), + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + parent: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantErr: NewStatusError(oidc.ErrInteractionRequired().WithDescription("sign in"), http.StatusBadRequest), + wantLog: `{ + "level":"WARN", + "msg":"auth request: not redirecting", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + { + name: "auth request, redirect disabled", + args: args{ + ctx: context.Background(), + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + parent: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"), + }, + wantErr: NewStatusError(oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"), http.StatusBadRequest), + wantLog: `{ + "level":"WARN", + "msg":"auth request: not redirecting", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"http://example.com/callback", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"oops", + "type":"invalid_request", + "redirect_disabled":true + } + }`, + }, + { + name: "auth request, url parse error", + args: args{ + ctx: context.Background(), + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "can't parse this!\n", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + parent: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantErr: func() error { + //lint:ignore SA1007 just recreating the error for testing + _, err := url.Parse("can't parse this!\n") + err = oidc.ErrServerError().WithParent(err) + return NewStatusError(err, http.StatusBadRequest) + }(), + wantLog: `{ + "level":"ERROR", + "msg":"auth response URL", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"can't parse this!\n", + "response_type":"responseType", + "scopes":"a b" + }, + "error":{ + "type":"server_error", + "parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + { + name: "auth request redirect", + args: args{ + ctx: context.Background(), + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + parent: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + want: &Redirect{ + URL: "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1", + }, + wantLog: `{ + "level":"WARN", + "msg":"auth request redirect", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"http://example.com/callback", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + }, + "url":"http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1" + }`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logOut := new(strings.Builder) + logger := slog.New( + slog.NewJSONHandler(logOut, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }).WithAttrs([]slog.Attr{slog.String("time", "not")}), + ) + encoder := schema.NewEncoder() + + got, err := TryErrorRedirect(tt.args.ctx, tt.args.authReq, tt.args.parent, encoder, logger) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + + gotLog := logOut.String() + t.Log(gotLog) + assert.JSONEq(t, tt.wantLog, gotLog, "log output") + }) + } +} + +func TestNewStatusError(t *testing.T) { + err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError) + + want := "Internal Server Error: io: read/write on closed pipe" + got := fmt.Sprint(err) + assert.Equal(t, want, got) +} + +func TestAsStatusError(t *testing.T) { + type args struct { + err error + statusCode int + } + tests := []struct { + name string + args args + want string + }{ + { + name: "already status error", + args: args{ + err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError), + statusCode: http.StatusBadRequest, + }, + want: "Internal Server Error: io: read/write on closed pipe", + }, + { + name: "oidc error", + args: args{ + err: oidc.ErrAcrInvalid, + statusCode: http.StatusBadRequest, + }, + want: "Bad Request: acr is invalid", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := AsStatusError(tt.args.err, tt.args.statusCode) + got := fmt.Sprint(err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestStatusError_Unwrap(t *testing.T) { + err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError) + require.ErrorIs(t, err, io.ErrClosedPipe) +} + +func TestStatusError_Is(t *testing.T) { + type args struct { + err error + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "nil error", + args: args{err: nil}, + want: false, + }, + { + name: "other error", + args: args{err: io.EOF}, + want: false, + }, + { + name: "other parent", + args: args{err: NewStatusError(io.EOF, http.StatusInternalServerError)}, + want: false, + }, + { + name: "other status", + args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInsufficientStorage)}, + want: false, + }, + { + name: "same", + args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)}, + want: true, + }, + { + name: "wrapped", + args: args{err: fmt.Errorf("wrap: %w", NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError))}, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError) + if got := e.Is(tt.args.err); got != tt.want { + t.Errorf("StatusError.Is() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWriteError(t *testing.T) { + tests := []struct { + name string + err error + wantStatus int + wantBody string + wantLog string + }{ + { + name: "not a status or oidc error", + err: io.ErrClosedPipe, + wantStatus: http.StatusBadRequest, + wantBody: `{ + "error":"server_error", + "error_description":"io: read/write on closed pipe" + }`, + wantLog: `{ + "level":"ERROR", + "msg":"request error", + "oidc_error":{ + "description":"io: read/write on closed pipe", + "parent":"io: read/write on closed pipe", + "type":"server_error" + }, + "time":"not" + }`, + }, + { + name: "status error w/o oidc", + err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError), + wantStatus: http.StatusInternalServerError, + wantBody: `{ + "error":"server_error", + "error_description":"io: read/write on closed pipe" + }`, + wantLog: `{ + "level":"ERROR", + "msg":"request error", + "oidc_error":{ + "description":"io: read/write on closed pipe", + "parent":"io: read/write on closed pipe", + "type":"server_error" + }, + "time":"not" + }`, + }, + { + name: "oidc error w/o status", + err: oidc.ErrInvalidRequest().WithDescription("oops"), + wantStatus: http.StatusBadRequest, + wantBody: `{ + "error":"invalid_request", + "error_description":"oops" + }`, + wantLog: `{ + "level":"WARN", + "msg":"request error", + "oidc_error":{ + "description":"oops", + "type":"invalid_request" + }, + "time":"not" + }`, + }, + { + name: "status with oidc error", + err: NewStatusError( + oidc.ErrUnauthorizedClient().WithDescription("oops"), + http.StatusUnauthorized, + ), + wantStatus: http.StatusUnauthorized, + wantBody: `{ + "error":"unauthorized_client", + "error_description":"oops" + }`, + wantLog: `{ + "level":"WARN", + "msg":"request error", + "oidc_error":{ + "description":"oops", + "type":"unauthorized_client" + }, + "time":"not" + }`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logOut := new(strings.Builder) + logger := slog.New( + slog.NewJSONHandler(logOut, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }).WithAttrs([]slog.Attr{slog.String("time", "not")}), + ) + r := httptest.NewRequest("GET", "/target", nil) + w := httptest.NewRecorder() + + WriteError(w, r, tt.err, logger) + res := w.Result() + assert.Equal(t, tt.wantStatus, res.StatusCode, "status code") + gotBody, err := io.ReadAll(res.Body) + require.NoError(t, err) + assert.JSONEq(t, tt.wantBody, string(gotBody), "body") + assert.JSONEq(t, tt.wantLog, logOut.String()) + }) + } +} diff --git a/pkg/op/mock/configuration.mock.go b/pkg/op/mock/configuration.mock.go index 96429ddb..f392a455 100644 --- a/pkg/op/mock/configuration.mock.go +++ b/pkg/op/mock/configuration.mock.go @@ -65,10 +65,10 @@ func (mr *MockConfigurationMockRecorder) AuthMethodPrivateKeyJWTSupported() *gom } // AuthorizationEndpoint mocks base method. -func (m *MockConfiguration) AuthorizationEndpoint() op.Endpoint { +func (m *MockConfiguration) AuthorizationEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AuthorizationEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -107,10 +107,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call { } // DeviceAuthorizationEndpoint mocks base method. -func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint { +func (m *MockConfiguration) DeviceAuthorizationEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -121,10 +121,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.C } // EndSessionEndpoint mocks base method. -func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint { +func (m *MockConfiguration) EndSessionEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "EndSessionEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -233,10 +233,10 @@ func (mr *MockConfigurationMockRecorder) IntrospectionAuthMethodPrivateKeyJWTSup } // IntrospectionEndpoint mocks base method. -func (m *MockConfiguration) IntrospectionEndpoint() op.Endpoint { +func (m *MockConfiguration) IntrospectionEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "IntrospectionEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -275,10 +275,10 @@ func (mr *MockConfigurationMockRecorder) IssuerFromRequest(arg0 interface{}) *go } // KeysEndpoint mocks base method. -func (m *MockConfiguration) KeysEndpoint() op.Endpoint { +func (m *MockConfiguration) KeysEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "KeysEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -331,10 +331,10 @@ func (mr *MockConfigurationMockRecorder) RevocationAuthMethodPrivateKeyJWTSuppor } // RevocationEndpoint mocks base method. -func (m *MockConfiguration) RevocationEndpoint() op.Endpoint { +func (m *MockConfiguration) RevocationEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RevocationEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -373,10 +373,10 @@ func (mr *MockConfigurationMockRecorder) SupportedUILocales() *gomock.Call { } // TokenEndpoint mocks base method. -func (m *MockConfiguration) TokenEndpoint() op.Endpoint { +func (m *MockConfiguration) TokenEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "TokenEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -401,10 +401,10 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported } // UserinfoEndpoint mocks base method. -func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint { +func (m *MockConfiguration) UserinfoEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UserinfoEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } diff --git a/pkg/op/op.go b/pkg/op/op.go index 0175d7fe..55ee9866 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -32,7 +32,7 @@ const ( ) var ( - DefaultEndpoints = &endpoints{ + DefaultEndpoints = &Endpoints{ Authorization: NewEndpoint(defaultAuthorizationEndpoint), Token: NewEndpoint(defaultTokenEndpoint), Introspection: NewEndpoint(defaultIntrospectEndpoint), @@ -131,16 +131,17 @@ type Config struct { DeviceAuthorization DeviceAuthorizationConfig } -type endpoints struct { - Authorization Endpoint - Token Endpoint - Introspection Endpoint - Userinfo Endpoint - Revocation Endpoint - EndSession Endpoint - CheckSessionIframe Endpoint - JwksURI Endpoint - DeviceAuthorization Endpoint +// Endpoints defines endpoint routes. +type Endpoints struct { + Authorization *Endpoint + Token *Endpoint + Introspection *Endpoint + Userinfo *Endpoint + Revocation *Endpoint + EndSession *Endpoint + CheckSessionIframe *Endpoint + JwksURI *Endpoint + DeviceAuthorization *Endpoint } // NewOpenIDProvider creates a provider. The provider provides (with HttpHandler()) @@ -212,7 +213,7 @@ type Provider struct { config *Config issuer IssuerFromRequest insecure bool - endpoints *endpoints + endpoints *Endpoints storage Storage keySet *openIDKeySet crypto Crypto @@ -233,35 +234,35 @@ func (o *Provider) Insecure() bool { return o.insecure } -func (o *Provider) AuthorizationEndpoint() Endpoint { +func (o *Provider) AuthorizationEndpoint() *Endpoint { return o.endpoints.Authorization } -func (o *Provider) TokenEndpoint() Endpoint { +func (o *Provider) TokenEndpoint() *Endpoint { return o.endpoints.Token } -func (o *Provider) IntrospectionEndpoint() Endpoint { +func (o *Provider) IntrospectionEndpoint() *Endpoint { return o.endpoints.Introspection } -func (o *Provider) UserinfoEndpoint() Endpoint { +func (o *Provider) UserinfoEndpoint() *Endpoint { return o.endpoints.Userinfo } -func (o *Provider) RevocationEndpoint() Endpoint { +func (o *Provider) RevocationEndpoint() *Endpoint { return o.endpoints.Revocation } -func (o *Provider) EndSessionEndpoint() Endpoint { +func (o *Provider) EndSessionEndpoint() *Endpoint { return o.endpoints.EndSession } -func (o *Provider) DeviceAuthorizationEndpoint() Endpoint { +func (o *Provider) DeviceAuthorizationEndpoint() *Endpoint { return o.endpoints.DeviceAuthorization } -func (o *Provider) KeysEndpoint() Endpoint { +func (o *Provider) KeysEndpoint() *Endpoint { return o.endpoints.JwksURI } @@ -420,7 +421,7 @@ func WithAllowInsecure() Option { } } -func WithCustomAuthEndpoint(endpoint Endpoint) Option { +func WithCustomAuthEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -430,7 +431,7 @@ func WithCustomAuthEndpoint(endpoint Endpoint) Option { } } -func WithCustomTokenEndpoint(endpoint Endpoint) Option { +func WithCustomTokenEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -440,7 +441,7 @@ func WithCustomTokenEndpoint(endpoint Endpoint) Option { } } -func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option { +func WithCustomIntrospectionEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -450,7 +451,7 @@ func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option { } } -func WithCustomUserinfoEndpoint(endpoint Endpoint) Option { +func WithCustomUserinfoEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -460,7 +461,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option { } } -func WithCustomRevocationEndpoint(endpoint Endpoint) Option { +func WithCustomRevocationEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -470,7 +471,7 @@ func WithCustomRevocationEndpoint(endpoint Endpoint) Option { } } -func WithCustomEndSessionEndpoint(endpoint Endpoint) Option { +func WithCustomEndSessionEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -480,7 +481,7 @@ func WithCustomEndSessionEndpoint(endpoint Endpoint) Option { } } -func WithCustomKeysEndpoint(endpoint Endpoint) Option { +func WithCustomKeysEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -490,7 +491,7 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option { } } -func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option { +func WithCustomDeviceAuthorizationEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -500,8 +501,16 @@ func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option { } } -func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option { +// WithCustomEndpoints sets multiple endpoints at once. +// Non of the endpoints may be nil, or an error will +// be returned when the Option used by the Provider. +func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys *Endpoint) Option { return func(o *Provider) error { + for _, e := range []*Endpoint{auth, token, userInfo, revocation, endSession, keys} { + if err := e.Validate(); err != nil { + return err + } + } o.endpoints.Authorization = auth o.endpoints.Token = token o.endpoints.Userinfo = userInfo diff --git a/pkg/op/op_test.go b/pkg/op/op_test.go index d33b39d5..abe53bc1 100644 --- a/pkg/op/op_test.go +++ b/pkg/op/op_test.go @@ -395,3 +395,54 @@ func TestRoutes(t *testing.T) { }) } } + +func TestWithCustomEndpoints(t *testing.T) { + type args struct { + auth *op.Endpoint + token *op.Endpoint + userInfo *op.Endpoint + revocation *op.Endpoint + endSession *op.Endpoint + keys *op.Endpoint + } + tests := []struct { + name string + args args + wantErr error + }{ + { + name: "all nil", + args: args{}, + wantErr: op.ErrNilEndpoint, + }, + { + name: "all set", + args: args{ + auth: op.NewEndpoint("/authorize"), + token: op.NewEndpoint("/oauth/token"), + userInfo: op.NewEndpoint("/userinfo"), + revocation: op.NewEndpoint("/revoke"), + endSession: op.NewEndpoint("/end_session"), + keys: op.NewEndpoint("/keys"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := op.NewOpenIDProvider(testIssuer, testConfig, + storage.NewStorage(storage.NewUserStore(testIssuer)), + op.WithCustomEndpoints(tt.args.auth, tt.args.token, tt.args.userInfo, tt.args.revocation, tt.args.endSession, tt.args.keys), + ) + require.ErrorIs(t, err, tt.wantErr) + if tt.wantErr != nil { + return + } + assert.Equal(t, tt.args.auth, provider.AuthorizationEndpoint()) + assert.Equal(t, tt.args.token, provider.TokenEndpoint()) + assert.Equal(t, tt.args.userInfo, provider.UserinfoEndpoint()) + assert.Equal(t, tt.args.revocation, provider.RevocationEndpoint()) + assert.Equal(t, tt.args.endSession, provider.EndSessionEndpoint()) + assert.Equal(t, tt.args.keys, provider.KeysEndpoint()) + }) + } +} diff --git a/pkg/op/probes.go b/pkg/op/probes.go index 9ef5bb56..cb3853d8 100644 --- a/pkg/op/probes.go +++ b/pkg/op/probes.go @@ -41,9 +41,9 @@ func ReadyStorage(s Storage) ProbesFn { } func ok(w http.ResponseWriter) { - httphelper.MarshalJSON(w, status{"ok"}) + httphelper.MarshalJSON(w, Status{"ok"}) } -type status struct { +type Status struct { Status string `json:"status,omitempty"` } diff --git a/pkg/op/server.go b/pkg/op/server.go new file mode 100644 index 00000000..a9cdcf5f --- /dev/null +++ b/pkg/op/server.go @@ -0,0 +1,346 @@ +package op + +import ( + "context" + "net/http" + "net/url" + + "github.com/muhlemmer/gu" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +// Server describes the interface that needs to be implemented to serve +// OpenID Connect and Oauth2 standard requests. +// +// Methods are called after the HTTP route is resolved and +// the request body is parsed into the Request's Data field. +// When a method is called, it can be assumed that required fields, +// as described in their relevant standard, are validated already. +// The Response Data field may be of any type to allow flexibility +// to extend responses with custom fields. There are however requirements +// in the standards regarding the response models. Where applicable +// the method documentation gives a recommended type which can be used +// directly or extended upon. +// +// The addition of new methods is not considered a breaking change +// as defined by semver rules. +// Implementations MUST embed [UnimplementedServer] to maintain +// forward compatibility. +// +// EXPERIMENTAL: may change until v4 +type Server interface { + // Health returns a status of "ok" once the Server is listening. + // The recommended Response Data type is [Status]. + Health(context.Context, *Request[struct{}]) (*Response, error) + + // Ready returns a status of "ok" once all dependencies, + // such as database storage, are ready. + // An error can be returned to explain what is not ready. + // The recommended Response Data type is [Status]. + Ready(context.Context, *Request[struct{}]) (*Response, error) + + // Discovery returns the OpenID Provider Configuration Information for this server. + // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig + // The recommended Response Data type is [oidc.DiscoveryConfiguration]. + Discovery(context.Context, *Request[struct{}]) (*Response, error) + + // Keys serves the JWK set which the client can use verify signatures from the op. + // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata `jwks_uri` key. + // The recommended Response Data type is [jose.JSONWebKeySet]. + Keys(context.Context, *Request[struct{}]) (*Response, error) + + // VerifyAuthRequest verifies the Auth Request and + // adds the Client to the request. + // + // When the `request` field is populated with a + // "Request Object" JWT, it needs to be Validated + // and its claims overwrite any fields in the AuthRequest. + // If the implementation does not support "Request Object", + // it MUST return an [oidc.ErrRequestNotSupported]. + // https://openid.net/specs/openid-connect-core-1_0.html#RequestObject + VerifyAuthRequest(context.Context, *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) + + // Authorize initiates the authorization flow and redirects to a login page. + // See the various https://openid.net/specs/openid-connect-core-1_0.html + // authorize endpoint sections (one for each type of flow). + Authorize(context.Context, *ClientRequest[oidc.AuthRequest]) (*Redirect, error) + + // DeviceAuthorization initiates the device authorization flow. + // https://datatracker.ietf.org/doc/html/rfc8628#section-3.1 + // The recommended Response Data type is [oidc.DeviceAuthorizationResponse]. + DeviceAuthorization(context.Context, *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) + + // VerifyClient is called on most oauth/token handlers to authenticate, + // using either a secret (POST, Basic) or assertion (JWT). + // If no secrets are provided, the client must be public. + // This method is called before each method that takes a + // [ClientRequest] argument. + VerifyClient(context.Context, *Request[ClientCredentials]) (Client, error) + + // CodeExchange returns Tokens after an authorization code + // is obtained in a successful Authorize flow. + // It is called by the Token endpoint handler when + // grant_type has the value authorization_code + // https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint + // The recommended Response Data type is [oidc.AccessTokenResponse]. + CodeExchange(context.Context, *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) + + // RefreshToken returns new Tokens after verifying a Refresh token. + // It is called by the Token endpoint handler when + // grant_type has the value refresh_token + // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens + // The recommended Response Data type is [oidc.AccessTokenResponse]. + RefreshToken(context.Context, *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) + + // JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant + // It is called by the Token endpoint handler when + // grant_type has the value urn:ietf:params:oauth:grant-type:jwt-bearer + // https://datatracker.ietf.org/doc/html/rfc7523#section-2.1 + // The recommended Response Data type is [oidc.AccessTokenResponse]. + JWTProfile(context.Context, *Request[oidc.JWTProfileGrantRequest]) (*Response, error) + + // TokenExchange handles the OAuth 2.0 token exchange grant + // It is called by the Token endpoint handler when + // grant_type has the value urn:ietf:params:oauth:grant-type:token-exchange + // https://datatracker.ietf.org/doc/html/rfc8693 + // The recommended Response Data type is [oidc.AccessTokenResponse]. + TokenExchange(context.Context, *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) + + // ClientCredentialsExchange handles the OAuth 2.0 client credentials grant + // It is called by the Token endpoint handler when + // grant_type has the value client_credentials + // https://datatracker.ietf.org/doc/html/rfc6749#section-4.4 + // The recommended Response Data type is [oidc.AccessTokenResponse]. + ClientCredentialsExchange(context.Context, *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) + + // DeviceToken handles the OAuth 2.0 Device Authorization Grant + // It is called by the Token endpoint handler when + // grant_type has the value urn:ietf:params:oauth:grant-type:device_code. + // It is typically called in a polling fashion and appropriate errors + // should be returned to signal authorization_pending or access_denied etc. + // https://datatracker.ietf.org/doc/html/rfc8628#section-3.4, + // https://datatracker.ietf.org/doc/html/rfc8628#section-3.5. + // The recommended Response Data type is [oidc.AccessTokenResponse]. + DeviceToken(context.Context, *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) + + // Introspect handles the OAuth 2.0 Token Introspection endpoint. + // https://datatracker.ietf.org/doc/html/rfc7662 + // The recommended Response Data type is [oidc.IntrospectionResponse]. + Introspect(context.Context, *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) + + // UserInfo handles the UserInfo endpoint and returns Claims about the authenticated End-User. + // https://openid.net/specs/openid-connect-core-1_0.html#UserInfo + // The recommended Response Data type is [oidc.UserInfo]. + UserInfo(context.Context, *Request[oidc.UserInfoRequest]) (*Response, error) + + // Revocation handles token revocation using an access or refresh token. + // https://datatracker.ietf.org/doc/html/rfc7009 + // There are no response requirements. Data may remain empty. + Revocation(context.Context, *ClientRequest[oidc.RevocationRequest]) (*Response, error) + + // EndSession handles the OpenID Connect RP-Initiated Logout. + // https://openid.net/specs/openid-connect-rpinitiated-1_0.html + // There are no response requirements. Data may remain empty. + EndSession(context.Context, *Request[oidc.EndSessionRequest]) (*Redirect, error) + + // mustImpl forces implementations to embed the UnimplementedServer for forward + // compatibility with the interface. + mustImpl() +} + +// Request contains the [http.Request] informational fields +// and parsed Data from the request body (POST) or URL parameters (GET). +// Data can be assumed to be validated according to the applicable +// standard for the specific endpoints. +// +// EXPERIMENTAL: may change until v4 +type Request[T any] struct { + Method string + URL *url.URL + Header http.Header + Form url.Values + PostForm url.Values + Data *T +} + +func (r *Request[_]) path() string { + return r.URL.Path +} + +func newRequest[T any](r *http.Request, data *T) *Request[T] { + return &Request[T]{ + Method: r.Method, + URL: r.URL, + Header: r.Header, + Form: r.Form, + PostForm: r.PostForm, + Data: data, + } +} + +// ClientRequest is a Request with a verified client attached to it. +// Methods that receive this argument may assume the client was authenticated, +// or verified to be a public client. +// +// EXPERIMENTAL: may change until v4 +type ClientRequest[T any] struct { + *Request[T] + Client Client +} + +func newClientRequest[T any](r *http.Request, data *T, client Client) *ClientRequest[T] { + return &ClientRequest[T]{ + Request: newRequest[T](r, data), + Client: client, + } +} + +// Response object for most [Server] methods. +// +// EXPERIMENTAL: may change until v4 +type Response struct { + // Header map will be merged with the + // header on the [http.ResponseWriter]. + Header http.Header + + // Data will be JSON marshaled to + // the response body. + // We allow any type, so that implementations + // can extend the standard types as they wish. + // However, each method will recommend which + // (base) type to use as model, in order to + // be compliant with the standards. + Data any +} + +// NewResponse creates a new response for data, +// without custom headers. +func NewResponse(data any) *Response { + return &Response{ + Data: data, + } +} + +func (resp *Response) writeOut(w http.ResponseWriter) { + gu.MapMerge(resp.Header, w.Header()) + httphelper.MarshalJSON(w, resp.Data) +} + +// Redirect is a special response type which will +// initiate a [http.StatusFound] redirect. +// The Params field will be encoded and set to the +// URL's RawQuery field before building the URL. +// +// EXPERIMENTAL: may change until v4 +type Redirect struct { + // Header map will be merged with the + // header on the [http.ResponseWriter]. + Header http.Header + + URL string +} + +func NewRedirect(url string) *Redirect { + return &Redirect{URL: url} +} + +func (red *Redirect) writeOut(w http.ResponseWriter, r *http.Request) { + gu.MapMerge(r.Header, w.Header()) + http.Redirect(w, r, red.URL, http.StatusFound) +} + +type UnimplementedServer struct{} + +// UnimplementedStatusCode is the status code returned for methods +// that are not yet implemented. +// Note that this means methods in the sense of the Go interface, +// and not http methods covered by "501 Not Implemented". +var UnimplementedStatusCode = http.StatusNotFound + +func unimplementedError(r interface{ path() string }) StatusError { + err := oidc.ErrServerError().WithDescription("%s not implemented on this server", r.path()) + return NewStatusError(err, UnimplementedStatusCode) +} + +func unimplementedGrantError(gt oidc.GrantType) StatusError { + err := oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", gt) + return NewStatusError(err, http.StatusBadRequest) // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 +} + +func (UnimplementedServer) mustImpl() {} + +func (UnimplementedServer) Health(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) { + if r.Data.RequestParam != "" { + return nil, oidc.ErrRequestNotSupported() + } + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (*Redirect, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeCode) +} + +func (UnimplementedServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken) +} + +func (UnimplementedServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeBearer) +} + +func (UnimplementedServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange) +} + +func (UnimplementedServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials) +} + +func (UnimplementedServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode) +} + +func (UnimplementedServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) { + return nil, unimplementedError(r) +} diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go new file mode 100644 index 00000000..3fb481d1 --- /dev/null +++ b/pkg/op/server_http.go @@ -0,0 +1,480 @@ +package op + +import ( + "context" + "net/http" + "net/url" + + "github.com/go-chi/chi" + "github.com/rs/cors" + "github.com/zitadel/logging" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/schema" + "golang.org/x/exp/slog" +) + +// RegisterServer registers an implementation of Server. +// The resulting handler takes care of routing and request parsing, +// with some basic validation of required fields. +// The routes can be customized with [WithEndpoints]. +// +// EXPERIMENTAL: may change until v4 +func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) http.Handler { + decoder := schema.NewDecoder() + decoder.IgnoreUnknownKeys(true) + + ws := &webServer{ + server: server, + endpoints: endpoints, + decoder: decoder, + logger: slog.Default(), + } + + for _, option := range options { + option(ws) + } + + ws.createRouter() + return ws +} + +type ServerOption func(s *webServer) + +// WithHTTPMiddleware sets the passed middleware chain to the root of +// the Server's router. +func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption { + return func(s *webServer) { + s.middleware = m + } +} + +// WithDecoder overrides the default decoder, +// which is a [schema.Decoder] with IgnoreUnknownKeys set to true. +func WithDecoder(decoder httphelper.Decoder) ServerOption { + return func(s *webServer) { + s.decoder = decoder + } +} + +// WithFallbackLogger overrides the fallback logger, which +// is used when no logger was found in the context. +// Defaults to [slog.Default]. +func WithFallbackLogger(logger *slog.Logger) ServerOption { + return func(s *webServer) { + s.logger = logger + } +} + +type webServer struct { + http.Handler + server Server + middleware []func(http.Handler) http.Handler + endpoints Endpoints + decoder httphelper.Decoder + logger *slog.Logger +} + +func (s *webServer) getLogger(ctx context.Context) *slog.Logger { + if logger, ok := logging.FromContext(ctx); ok { + return logger + } + return s.logger +} + +func (s *webServer) createRouter() { + router := chi.NewRouter() + router.Use(cors.New(defaultCORSOptions).Handler) + router.Use(s.middleware...) + router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health)) + router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready)) + router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery)) + + s.endpointRoute(router, s.endpoints.Authorization, s.authorizeHandler) + s.endpointRoute(router, s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler)) + s.endpointRoute(router, s.endpoints.Token, s.tokensHandler) + s.endpointRoute(router, s.endpoints.Introspection, s.withClient(s.introspectionHandler)) + s.endpointRoute(router, s.endpoints.Userinfo, s.userInfoHandler) + s.endpointRoute(router, s.endpoints.Revocation, s.withClient(s.revocationHandler)) + s.endpointRoute(router, s.endpoints.EndSession, s.endSessionHandler) + s.endpointRoute(router, s.endpoints.JwksURI, simpleHandler(s, s.server.Keys)) + s.Handler = router +} + +func (s *webServer) endpointRoute(router *chi.Mux, e *Endpoint, hf http.HandlerFunc) { + if e != nil { + router.HandleFunc(e.Relative(), hf) + s.logger.Info("registered route", "endpoint", e.Relative()) + } +} + +type clientHandler func(w http.ResponseWriter, r *http.Request, client Client) + +func (s *webServer) withClient(handler clientHandler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + client, err := s.verifyRequestClient(r) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" { + if !ValidateGrantType(client, grantType) { + WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.getLogger(r.Context())) + return + } + } + handler(w, r, client) + } +} + +func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) { + if err = r.ParseForm(); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) + } + cc := new(ClientCredentials) + if err = s.decoder.Decode(cc, r.Form); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) + } + // Basic auth takes precedence, so if set it overwrites the form data. + if clientID, clientSecret, ok := r.BasicAuth(); ok { + cc.ClientID, err = url.QueryUnescape(clientID) + if err != nil { + return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err) + } + cc.ClientSecret, err = url.QueryUnescape(clientSecret) + if err != nil { + return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err) + } + } + if cc.ClientID == "" && cc.ClientAssertion == "" { + return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided") + } + if cc.ClientAssertion != "" && cc.ClientAssertionType != oidc.ClientAssertionTypeJWTAssertion { + return nil, oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type %s", cc.ClientAssertionType) + } + return s.server.VerifyClient(r.Context(), &Request[ClientCredentials]{ + Method: r.Method, + URL: r.URL, + Header: r.Header, + Form: r.Form, + Data: cc, + }) +} + +func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.AuthRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + redirect, err := s.authorize(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + redirect.writeOut(w, r) +} + +func (s *webServer) authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (_ *Redirect, err error) { + cr, err := s.server.VerifyAuthRequest(ctx, r) + if err != nil { + return nil, err + } + authReq := cr.Data + if authReq.RedirectURI == "" { + return nil, ErrAuthReqMissingRedirectURI + } + authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge) + if err != nil { + return nil, err + } + authReq.Scopes, err = ValidateAuthReqScopes(cr.Client, authReq.Scopes) + if err != nil { + return nil, err + } + if err := ValidateAuthReqRedirectURI(cr.Client, authReq.RedirectURI, authReq.ResponseType); err != nil { + return nil, err + } + if err := ValidateAuthReqResponseType(cr.Client, authReq.ResponseType); err != nil { + return nil, err + } + return s.server.Authorize(ctx, cr) +} + +func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp, err := s.server.DeviceAuthorization(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context())) + return + } + + switch grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType { + case oidc.GrantTypeCode: + s.withClient(s.codeExchangeHandler)(w, r) + case oidc.GrantTypeRefreshToken: + s.withClient(s.refreshTokenHandler)(w, r) + case oidc.GrantTypeClientCredentials: + s.withClient(s.clientCredentialsHandler)(w, r) + case oidc.GrantTypeBearer: + s.jwtProfileHandler(w, r) + case oidc.GrantTypeTokenExchange: + s.withClient(s.tokenExchangeHandler)(w, r) + case oidc.GrantTypeDeviceCode: + s.withClient(s.deviceTokenHandler)(w, r) + case "": + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.getLogger(r.Context())) + default: + WriteError(w, r, unimplementedGrantError(grantType), s.getLogger(r.Context())) + } +} + +func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.JWTProfileGrantRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.Assertion == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.JWTProfile(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.AccessTokenRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.Code == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.getLogger(r.Context())) + return + } + if request.RedirectURI == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.RefreshTokenRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.RefreshToken == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.RefreshToken(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.TokenExchangeRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.SubjectToken == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.getLogger(r.Context())) + return + } + if request.SubjectTokenType == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.getLogger(r.Context())) + return + } + if !request.SubjectTokenType.IsSupported() { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.getLogger(r.Context())) + return + } + if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.getLogger(r.Context())) + return + } + if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.getLogger(r.Context())) + return + } + resp, err := s.server.TokenExchange(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) { + if client.AuthMethod() == oidc.AuthMethodNone { + WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context())) + return + } + + request, err := decodeRequest[oidc.ClientCredentialsRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp, err := s.server.ClientCredentialsExchange(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.DeviceAccessTokenRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.DeviceCode == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.DeviceToken(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, client Client) { + if client.AuthMethod() == oidc.AuthMethodNone { + WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context())) + return + } + request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.Token == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.Introspect(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if token, err := getAccessToken(r); err == nil { + request.AccessToken = token + } + if request.AccessToken == "" { + err = NewStatusError( + oidc.ErrInvalidRequest().WithDescription("access token missing"), + http.StatusUnauthorized, + ) + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp, err := s.server.UserInfo(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.Token == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.EndSessionRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp, err := s.server.EndSession(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w, r) +} + +func simpleHandler(s *webServer, method func(context.Context, *Request[struct{}]) (*Response, error)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context())) + return + } + resp, err := method(r.Context(), newRequest(r, &struct{}{})) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) + } +} + +func decodeRequest[R any](decoder httphelper.Decoder, r *http.Request, postOnly bool) (*R, error) { + dst := new(R) + if err := r.ParseForm(); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) + } + form := r.Form + if postOnly { + form = r.PostForm + } + if err := decoder.Decode(dst, form); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) + } + return dst, nil +} diff --git a/pkg/op/server_http_routes_test.go b/pkg/op/server_http_routes_test.go new file mode 100644 index 00000000..6a8b75d6 --- /dev/null +++ b/pkg/op/server_http_routes_test.go @@ -0,0 +1,345 @@ +package op_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/oidc/v3/pkg/client" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" +) + +func jwtProfile() (string, error) { + keyData, err := client.ConfigFromKeyFile("../../example/server/service-key1.json") + if err != nil { + return "", err + } + signer, err := client.NewSignerFromPrivateKeyByte([]byte(keyData.Key), keyData.KeyID) + if err != nil { + return "", err + } + return client.SignedJWTProfileAssertion(keyData.UserID, []string{testIssuer}, time.Hour, signer) +} + +func TestServerRoutes(t *testing.T) { + server := op.NewLegacyServer(testProvider, *op.DefaultEndpoints) + + storage := testProvider.Storage().(routesTestStorage) + ctx := op.ContextWithIssuer(context.Background(), testIssuer) + + client, err := storage.GetClientByClientID(ctx, "web") + require.NoError(t, err) + + oidcAuthReq := &oidc.AuthRequest{ + ClientID: client.GetID(), + RedirectURI: "https://example.com", + MaxAge: gu.Ptr[uint](300), + Scopes: oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess, oidc.ScopeEmail, oidc.ScopeProfile, oidc.ScopePhone}, + ResponseType: oidc.ResponseTypeCode, + } + + authReq, err := storage.CreateAuthRequest(ctx, oidcAuthReq, "id1") + require.NoError(t, err) + storage.AuthRequestDone(authReq.GetID()) + + accessToken, refreshToken, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "") + require.NoError(t, err) + accessTokenRevoke, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "") + require.NoError(t, err) + idToken, err := op.CreateIDToken(ctx, testIssuer, authReq, time.Hour, accessToken, "123", storage, client) + require.NoError(t, err) + jwtToken, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeJWT, testProvider, client, "") + require.NoError(t, err) + jwtProfileToken, err := jwtProfile() + require.NoError(t, err) + + oidcAuthReq.IDTokenHint = idToken + + serverURL, err := url.Parse(testIssuer) + require.NoError(t, err) + + type basicAuth struct { + username, password string + } + + tests := []struct { + name string + method string + path string + basicAuth *basicAuth + header map[string]string + values map[string]string + body map[string]string + wantCode int + headerContains map[string]string + json string // test for exact json output + contains []string // when the body output is not constant, we just check for snippets to be present in the response + }{ + { + name: "health", + method: http.MethodGet, + path: "/healthz", + wantCode: http.StatusOK, + json: `{"status":"ok"}`, + }, + { + name: "ready", + method: http.MethodGet, + path: "/ready", + wantCode: http.StatusOK, + json: `{"status":"ok"}`, + }, + { + name: "discovery", + method: http.MethodGet, + path: oidc.DiscoveryEndpoint, + wantCode: http.StatusOK, + json: `{"issuer":"https://localhost:9998/","authorization_endpoint":"https://localhost:9998/authorize","token_endpoint":"https://localhost:9998/oauth/token","introspection_endpoint":"https://localhost:9998/oauth/introspect","userinfo_endpoint":"https://localhost:9998/userinfo","revocation_endpoint":"https://localhost:9998/revoke","end_session_endpoint":"https://localhost:9998/end_session","device_authorization_endpoint":"https://localhost:9998/device_authorization","jwks_uri":"https://localhost:9998/keys","scopes_supported":["openid","profile","email","phone","address","offline_access"],"response_types_supported":["code","id_token","id_token token"],"grant_types_supported":["authorization_code","implicit","refresh_token","client_credentials","urn:ietf:params:oauth:grant-type:token-exchange","urn:ietf:params:oauth:grant-type:jwt-bearer","urn:ietf:params:oauth:grant-type:device_code"],"subject_types_supported":["public"],"id_token_signing_alg_values_supported":["RS256"],"request_object_signing_alg_values_supported":["RS256"],"token_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"token_endpoint_auth_signing_alg_values_supported":["RS256"],"revocation_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"revocation_endpoint_auth_signing_alg_values_supported":["RS256"],"introspection_endpoint_auth_methods_supported":["client_secret_basic","private_key_jwt"],"introspection_endpoint_auth_signing_alg_values_supported":["RS256"],"claims_supported":["sub","aud","exp","iat","iss","auth_time","nonce","acr","amr","c_hash","at_hash","act","scopes","client_id","azp","preferred_username","name","family_name","given_name","locale","email","email_verified","phone_number","phone_number_verified"],"code_challenge_methods_supported":["S256"],"ui_locales_supported":["en"],"request_parameter_supported":true,"request_uri_parameter_supported":false}`, + }, + { + name: "authorization", + method: http.MethodGet, + path: testProvider.AuthorizationEndpoint().Relative(), + values: map[string]string{ + "client_id": client.GetID(), + "redirect_uri": "https://example.com", + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + "response_type": string(oidc.ResponseTypeCode), + }, + wantCode: http.StatusFound, + headerContains: map[string]string{"Location": "/login/username?authRequestID="}, + }, + { + // This call will fail. A successfull test is already + // part of client/integration_test.go + name: "code exchange", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + values: map[string]string{ + "grant_type": string(oidc.GrantTypeCode), + "client_id": client.GetID(), + "client_secret": "secret", + "redirect_uri": "https://example.com", + "code": "123", + }, + wantCode: http.StatusBadRequest, + json: `{"error":"invalid_grant", "error_description":"invalid code"}`, + }, + { + name: "JWT authorization", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + values: map[string]string{ + "grant_type": string(oidc.GrantTypeBearer), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + "assertion": jwtProfileToken, + }, + wantCode: http.StatusOK, + contains: []string{`{"access_token":`, `"token_type":"Bearer","expires_in":299}`}, + }, + { + name: "Token exchange", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + basicAuth: &basicAuth{"web", "secret"}, + values: map[string]string{ + "grant_type": string(oidc.GrantTypeTokenExchange), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + "subject_token": jwtToken, + "subject_token_type": string(oidc.AccessTokenType), + }, + wantCode: http.StatusOK, + contains: []string{ + `{"access_token":"`, + `","issued_token_type":"urn:ietf:params:oauth:token-type:refresh_token","token_type":"Bearer","expires_in":299,"scope":"openid offline_access","refresh_token":"`, + }, + }, + { + name: "Client credentials exchange", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + basicAuth: &basicAuth{"sid1", "verysecret"}, + values: map[string]string{ + "grant_type": string(oidc.GrantTypeClientCredentials), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + }, + wantCode: http.StatusOK, + contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`}, + }, + { + // This call will fail. A successfull test is already + // part of device_test.go + name: "device token", + method: http.MethodPost, + path: testProvider.TokenEndpoint().Relative(), + basicAuth: &basicAuth{"web", "secret"}, + header: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + }, + body: map[string]string{ + "grant_type": string(oidc.GrantTypeDeviceCode), + "device_code": "123", + }, + wantCode: http.StatusBadRequest, + json: `{"error":"access_denied","error_description":"The authorization request was denied."}`, + }, + { + name: "missing grant type", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + wantCode: http.StatusBadRequest, + json: `{"error":"invalid_request","error_description":"grant_type missing"}`, + }, + { + name: "unsupported grant type", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + values: map[string]string{ + "grant_type": "foo", + }, + wantCode: http.StatusBadRequest, + json: `{"error":"unsupported_grant_type","error_description":"foo not supported"}`, + }, + { + name: "introspection", + method: http.MethodGet, + path: testProvider.IntrospectionEndpoint().Relative(), + basicAuth: &basicAuth{"web", "secret"}, + values: map[string]string{ + "token": accessToken, + }, + wantCode: http.StatusOK, + json: `{"active":true,"scope":"openid offline_access email profile phone","client_id":"web","sub":"id1","username":"test-user@localhost","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`, + }, + { + name: "user info", + method: http.MethodGet, + path: testProvider.UserinfoEndpoint().Relative(), + header: map[string]string{ + "authorization": "Bearer " + accessToken, + }, + wantCode: http.StatusOK, + json: `{"sub":"id1","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`, + }, + { + name: "refresh token", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + values: map[string]string{ + "grant_type": string(oidc.GrantTypeRefreshToken), + "refresh_token": refreshToken, + "client_id": client.GetID(), + "client_secret": "secret", + }, + wantCode: http.StatusOK, + contains: []string{ + `{"access_token":"`, + `","token_type":"Bearer","refresh_token":"`, + `","expires_in":299,"id_token":"`, + }, + }, + { + name: "revoke", + method: http.MethodGet, + path: testProvider.RevocationEndpoint().Relative(), + basicAuth: &basicAuth{"web", "secret"}, + values: map[string]string{ + "token": accessTokenRevoke, + }, + wantCode: http.StatusOK, + }, + { + name: "end session", + method: http.MethodGet, + path: testProvider.EndSessionEndpoint().Relative(), + values: map[string]string{ + "id_token_hint": idToken, + "client_id": "web", + }, + wantCode: http.StatusFound, + headerContains: map[string]string{"Location": "/logged-out"}, + contains: []string{`Found.`}, + }, + { + name: "keys", + method: http.MethodGet, + path: testProvider.KeysEndpoint().Relative(), + wantCode: http.StatusOK, + contains: []string{ + `{"keys":[{"use":"sig","kty":"RSA","kid":"`, + `","alg":"RS256","n":"`, `","e":"AQAB"}]}`, + }, + }, + { + name: "device authorization", + method: http.MethodGet, + path: testProvider.DeviceAuthorizationEndpoint().Relative(), + basicAuth: &basicAuth{"web", "secret"}, + values: map[string]string{ + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + }, + wantCode: http.StatusOK, + contains: []string{ + `{"device_code":"`, `","user_code":"`, + `","verification_uri":"https://localhost:9998/device"`, + `"verification_uri_complete":"https://localhost:9998/device?user_code=`, + `","expires_in":300,"interval":5}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u := gu.PtrCopy(serverURL) + u.Path = tt.path + if tt.values != nil { + u.RawQuery = mapAsValues(tt.values) + } + var body io.Reader + if tt.body != nil { + body = strings.NewReader(mapAsValues(tt.body)) + } + + req := httptest.NewRequest(tt.method, u.String(), body) + for k, v := range tt.header { + req.Header.Set(k, v) + } + if tt.basicAuth != nil { + req.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password) + } + + rec := httptest.NewRecorder() + server.ServeHTTP(rec, req) + + resp := rec.Result() + require.NoError(t, err) + assert.Equal(t, tt.wantCode, resp.StatusCode) + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + respBodyString := string(respBody) + t.Log(respBodyString) + t.Log(resp.Header) + + if tt.json != "" { + assert.JSONEq(t, tt.json, respBodyString) + } + for _, c := range tt.contains { + assert.Contains(t, respBodyString, c) + } + for k, v := range tt.headerContains { + assert.Contains(t, resp.Header.Get(k), v) + } + }) + } +} diff --git a/pkg/op/server_http_test.go b/pkg/op/server_http_test.go new file mode 100644 index 00000000..86fe7ed8 --- /dev/null +++ b/pkg/op/server_http_test.go @@ -0,0 +1,1333 @@ +package op + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "testing" + "time" + + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/schema" + "golang.org/x/exp/slog" +) + +func TestRegisterServer(t *testing.T) { + server := UnimplementedServer{} + endpoints := Endpoints{ + Authorization: &Endpoint{ + path: "/auth", + }, + } + decoder := schema.NewDecoder() + logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) + + h := RegisterServer(server, endpoints, + WithDecoder(decoder), + WithFallbackLogger(logger), + ) + got := h.(*webServer) + assert.Equal(t, got.server, server) + assert.Equal(t, got.endpoints, endpoints) + assert.Equal(t, got.decoder, decoder) + assert.Equal(t, got.logger, logger) +} + +type testClient struct { + id string + appType ApplicationType + authMethod oidc.AuthMethod + accessTokenType AccessTokenType + responseTypes []oidc.ResponseType + grantTypes []oidc.GrantType + devMode bool +} + +type clientType string + +const ( + clientTypeWeb clientType = "web" + clientTypeNative clientType = "native" + clientTypeUserAgent clientType = "useragent" +) + +func newClient(kind clientType) *testClient { + client := &testClient{ + id: string(kind), + } + + switch kind { + case clientTypeWeb: + client.appType = ApplicationTypeWeb + client.authMethod = oidc.AuthMethodBasic + client.accessTokenType = AccessTokenTypeBearer + client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode} + case clientTypeNative: + client.appType = ApplicationTypeNative + client.authMethod = oidc.AuthMethodNone + client.accessTokenType = AccessTokenTypeBearer + client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode} + case clientTypeUserAgent: + client.appType = ApplicationTypeUserAgent + client.authMethod = oidc.AuthMethodBasic + client.accessTokenType = AccessTokenTypeJWT + client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeIDToken} + default: + panic(fmt.Errorf("invalid client type %s", kind)) + } + return client +} + +func (c *testClient) RedirectURIs() []string { + return []string{ + "https://registered.com/callback", + "http://registered.com/callback", + "http://localhost:9999/callback", + "custom://callback", + } +} + +func (c *testClient) PostLogoutRedirectURIs() []string { + return []string{} +} + +func (c *testClient) LoginURL(id string) string { + return "login?id=" + id +} + +func (c *testClient) ApplicationType() ApplicationType { + return c.appType +} + +func (c *testClient) AuthMethod() oidc.AuthMethod { + return c.authMethod +} + +func (c *testClient) GetID() string { + return c.id +} + +func (c *testClient) AccessTokenLifetime() time.Duration { + return 5 * time.Minute +} + +func (c *testClient) IDTokenLifetime() time.Duration { + return 5 * time.Minute +} + +func (c *testClient) AccessTokenType() AccessTokenType { + return c.accessTokenType +} + +func (c *testClient) ResponseTypes() []oidc.ResponseType { + return c.responseTypes +} + +func (c *testClient) GrantTypes() []oidc.GrantType { + return c.grantTypes +} + +func (c *testClient) DevMode() bool { + return c.devMode +} + +func (c *testClient) AllowedScopes() []string { + return nil +} + +func (c *testClient) RestrictAdditionalIdTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } +} + +func (c *testClient) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } +} + +func (c *testClient) IsScopeAllowed(scope string) bool { + return false +} + +func (c *testClient) IDTokenUserinfoClaimsAssertion() bool { + return false +} + +func (c *testClient) ClockSkew() time.Duration { + return 0 +} + +type requestVerifier struct { + UnimplementedServer + client Client +} + +func (s *requestVerifier) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) { + if s.client == nil { + return nil, oidc.ErrServerError() + } + return &ClientRequest[oidc.AuthRequest]{ + Request: r, + Client: s.client, + }, nil +} + +func (s *requestVerifier) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) { + if s.client == nil { + return nil, oidc.ErrServerError() + } + return s.client, nil +} + +var testDecoder = func() *schema.Decoder { + decoder := schema.NewDecoder() + decoder.IgnoreUnknownKeys(true) + return decoder +}() + +type webServerResult struct { + wantStatus int + wantBody string +} + +func runWebServerTest(t *testing.T, handler http.HandlerFunc, r *http.Request, want webServerResult) { + t.Helper() + if r.Method == http.MethodPost { + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + w := httptest.NewRecorder() + handler(w, r) + res := w.Result() + assert.Equal(t, want.wantStatus, res.StatusCode) + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + assert.JSONEq(t, want.wantBody, string(body)) +} + +func Test_webServer_withClient(t *testing.T) { + tests := []struct { + name string + r *http.Request + want webServerResult + }{ + { + name: "parse error", + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`, + }, + }, + { + name: "invalid grant type", + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=native&grant_type=bad&foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unauthorized_client", "error_description":"grant_type \"bad\" not allowed"}`, + }, + }, + { + name: "no grant type", + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=native&foo=bar")), + want: webServerResult{ + wantStatus: http.StatusOK, + wantBody: `{"foo":"bar"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: &requestVerifier{ + client: newClient(clientTypeNative), + }, + decoder: testDecoder, + logger: slog.Default(), + } + handler := func(w http.ResponseWriter, r *http.Request, client Client) { + fmt.Fprintf(w, `{"foo":%q}`, r.FormValue("foo")) + } + runWebServerTest(t, s.withClient(handler), tt.r, tt.want) + }) + } +} + +func Test_webServer_verifyRequestClient(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want Client + wantErr error + }{ + { + name: "parse form error", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + wantErr: oidc.ErrInvalidRequest().WithDescription("error parsing form"), + }, + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + wantErr: oidc.ErrInvalidRequest().WithDescription("error decoding form"), + }, + { + name: "basic auth, client_id error", + decoder: testDecoder, + r: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")) + r.SetBasicAuth(`%%%`, "secret") + return r + }(), + wantErr: oidc.ErrInvalidClient().WithDescription("invalid basic auth header"), + }, + { + name: "basic auth, client_secret error", + decoder: testDecoder, + r: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")) + r.SetBasicAuth("web", `%%%`) + return r + }(), + wantErr: oidc.ErrInvalidClient().WithDescription("invalid basic auth header"), + }, + { + name: "missing client id and assertion", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + wantErr: oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided"), + }, + { + name: "wrong assertion type", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar&client_assertion=xxx&client_assertion_type=wrong")), + wantErr: oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type wrong"), + }, + { + name: "unimplemented verify client called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar&client_id=web")), + wantErr: StatusError{ + parent: oidc.ErrServerError().WithDescription("/ not implemented on this server"), + statusCode: UnimplementedStatusCode, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + tt.r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + got, err := s.verifyRequestClient(tt.r) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_webServer_authorizeHandler(t *testing.T) { + type fields struct { + server Server + decoder httphelper.Decoder + } + tests := []struct { + name string + fields fields + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + fields: fields{ + server: &requestVerifier{}, + decoder: schema.NewDecoder(), + }, + r: httptest.NewRequest(http.MethodPost, "/authorize", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "authorize error", + fields: fields{ + server: &requestVerifier{}, + decoder: testDecoder, + }, + r: httptest.NewRequest(http.MethodPost, "/authorize", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"server_error"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: tt.fields.server, + decoder: tt.fields.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.authorizeHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_authorize(t *testing.T) { + type args struct { + ctx context.Context + r *Request[oidc.AuthRequest] + } + tests := []struct { + name string + server Server + args args + want *Redirect + wantErr error + }{ + { + name: "verify error", + server: &requestVerifier{}, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + }, + }, + }, + wantErr: oidc.ErrServerError(), + }, + { + name: "missing redirect", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + MaxAge: gu.Ptr[uint](300), + }, + }, + }, + wantErr: ErrAuthReqMissingRedirectURI, + }, + { + name: "invalid prompt", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone, oidc.PromptLogin}, + }, + }, + }, + wantErr: oidc.ErrInvalidRequest().WithDescription("The prompt parameter `none` must only be used as a single value"), + }, + { + name: "missing scopes", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone}, + }, + }, + }, + wantErr: oidc.ErrInvalidRequest(). + WithDescription("The scope of your request is missing. Please ensure some scopes are requested. " + + "If you have any questions, you may contact the administrator of the application."), + }, + { + name: "invalid redirect", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://example.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone}, + }, + }, + }, + wantErr: oidc.ErrInvalidRequestRedirectURI(). + WithDescription("The requested redirect_uri is missing in the client configuration. " + + "If you have any questions, you may contact the administrator of the application."), + }, + { + name: "invalid response type", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeIDToken, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone}, + }, + }, + }, + wantErr: oidc.ErrUnauthorizedClient().WithDescription("The requested response type is missing in the client configuration. " + + "If you have any questions, you may contact the administrator of the application."), + }, + { + name: "unimplemented Authorize called", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + URL: &url.URL{ + Path: "/authorize", + }, + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone}, + }, + }, + }, + wantErr: StatusError{ + parent: oidc.ErrServerError().WithDescription("/authorize not implemented on this server"), + statusCode: UnimplementedStatusCode, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: tt.server, + decoder: testDecoder, + logger: slog.Default(), + } + got, err := s.authorize(tt.args.ctx, tt.args.r) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_webServer_deviceAuthorizationHandler(t *testing.T) { + type fields struct { + server Server + decoder httphelper.Decoder + } + tests := []struct { + name string + fields fields + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + fields: fields{ + server: &requestVerifier{}, + decoder: schema.NewDecoder(), + }, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "unimplemented DeviceAuthorization called", + fields: fields{ + server: &requestVerifier{ + client: newClient(clientTypeNative), + }, + decoder: testDecoder, + }, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=native_client")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: tt.fields.server, + decoder: tt.fields.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.deviceAuthorizationHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_tokensHandler(t *testing.T) { + tests := []struct { + name string + r *http.Request + want webServerResult + }{ + { + name: "parse form error", + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`, + }, + }, + { + name: "missing grant type", + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"grant_type missing"}`, + }, + }, + { + name: "invalid grant type", + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("grant_type=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"bar not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + logger: slog.Default(), + } + runWebServerTest(t, s.tokensHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_jwtProfileHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "assertion missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"assertion missing"}`, + }, + }, + { + name: "unimplemented JWTProfile called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("assertion=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:jwt-bearer not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.jwtProfileHandler, tt.r, tt.want) + }) + } +} + +func runWebServerClientTest(t *testing.T, handler func(http.ResponseWriter, *http.Request, Client), r *http.Request, client Client, want webServerResult) { + t.Helper() + runWebServerTest(t, func(client Client) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + handler(w, r, client) + } + }(client), r, want) +} + +func Test_webServer_codeExchangeHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "code missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"code missing"}`, + }, + }, + { + name: "redirect missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("code=123")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"redirect_uri missing"}`, + }, + }, + { + name: "unimplemented CodeExchange called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("code=123&redirect_uri=https://example.com/callback")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"authorization_code not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.codeExchangeHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_refreshTokenHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "refresh token missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"refresh_token missing"}`, + }, + }, + { + name: "unimplemented RefreshToken called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("refresh_token=xxx")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"refresh_token not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.refreshTokenHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_tokenExchangeHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "subject token missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"subject_token missing"}`, + }, + }, + { + name: "subject token type missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"subject_token_type missing"}`, + }, + }, + { + name: "subject token type unsupported", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=foo")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"subject_token_type is not supported"}`, + }, + }, + { + name: "unsupported requested token type", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=foo")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"requested_token_type is not supported"}`, + }, + }, + { + name: "unsupported actor token type", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=urn:ietf:params:oauth:token-type:access_token&actor_token_type=foo")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"actor_token_type is not supported"}`, + }, + }, + { + name: "unimplemented TokenExchange called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=urn:ietf:params:oauth:token-type:access_token&actor_token_type=urn:ietf:params:oauth:token-type:access_token")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:token-exchange not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.tokenExchangeHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_clientCredentialsHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + client Client + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + client: newClient(clientTypeUserAgent), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "public client", + decoder: testDecoder, + client: newClient(clientTypeNative), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_client", "error_description":"client must be authenticated"}`, + }, + }, + { + name: "unimplemented ClientCredentialsExchange called", + decoder: testDecoder, + client: newClient(clientTypeUserAgent), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"client_credentials not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerClientTest(t, s.clientCredentialsHandler, tt.r, tt.client, tt.want) + }) + } +} + +func Test_webServer_deviceTokenHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "device code missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"device_code missing"}`, + }, + }, + { + name: "unimplemented DeviceToken called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("device_code=xxx")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:device_code not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.deviceTokenHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_introspectionHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + client Client + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + client: newClient(clientTypeUserAgent), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "public client", + decoder: testDecoder, + client: newClient(clientTypeNative), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_client", "error_description":"client must be authenticated"}`, + }, + }, + { + name: "token missing", + decoder: testDecoder, + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"token missing"}`, + }, + }, + { + name: "unimplemented Introspect called", + decoder: testDecoder, + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerClientTest(t, s.introspectionHandler, tt.r, tt.client, tt.want) + }) + } +} + +func Test_webServer_userInfoHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "access token missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusUnauthorized, + wantBody: `{"error":"invalid_request", "error_description":"access token missing"}`, + }, + }, + { + name: "unimplemented UserInfo called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("access_token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + { + name: "bearer", + decoder: testDecoder, + r: func() *http.Request { + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("authorization", strings.Join([]string{"Bearer", "xxx"}, " ")) + return r + }(), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.userInfoHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_revocationHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + client Client + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "token missing", + decoder: testDecoder, + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"token missing"}`, + }, + }, + { + name: "unimplemented Revocation called, confidential client", + decoder: testDecoder, + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + { + name: "unimplemented Revocation called, public client", + decoder: testDecoder, + client: newClient(clientTypeNative), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerClientTest(t, s.revocationHandler, tt.r, tt.client, tt.want) + }) + } +} + +func Test_webServer_endSessionHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "unimplemented EndSession called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("id_token_hint=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.endSessionHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_simpleHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + method func(context.Context, *Request[struct{}]) (*Response, error) + r *http.Request + want webServerResult + }{ + { + name: "parse error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`, + }, + }, + { + name: "method error", + decoder: schema.NewDecoder(), + method: func(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, io.ErrClosedPipe + }, + r: httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(make([]byte, 11<<20))), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, simpleHandler(s, tt.method), tt.r, tt.want) + }) + } +} + +func Test_decodeRequest(t *testing.T) { + type dst struct { + A string `schema:"a"` + B string `schema:"b"` + } + type args struct { + r *http.Request + postOnly bool + } + tests := []struct { + name string + args args + want *dst + wantErr error + }{ + { + name: "parse error", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + }, + wantErr: oidc.ErrInvalidRequest().WithDescription("error parsing form"), + }, + { + name: "decode error", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + }, + wantErr: oidc.ErrInvalidRequest().WithDescription("error decoding form"), + }, + { + name: "success, get", + args: args{ + r: httptest.NewRequest(http.MethodGet, "/?a=b&b=a", nil), + }, + want: &dst{ + A: "b", + B: "a", + }, + }, + { + name: "success, post only", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/?b=a", strings.NewReader("a=b&")), + postOnly: true, + }, + want: &dst{ + A: "b", + }, + }, + { + name: "success, post mixed", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/?b=a", strings.NewReader("a=b&")), + postOnly: false, + }, + want: &dst{ + A: "b", + B: "a", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.args.r.Method == http.MethodPost { + tt.args.r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + got, err := decodeRequest[dst](schema.NewDecoder(), tt.args.r, tt.args.postOnly) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/op/server_legacy.go b/pkg/op/server_legacy.go new file mode 100644 index 00000000..0a7de855 --- /dev/null +++ b/pkg/op/server_legacy.go @@ -0,0 +1,344 @@ +package op + +import ( + "context" + "errors" + "net/http" + "time" + + "github.com/go-chi/chi" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +// LegacyServer is an implementation of [Server[] that +// simply wraps a [OpenIDProvider]. +// It can be used to transition from the former Provider/Storage +// interfaces to the new Server interface. +type LegacyServer struct { + UnimplementedServer + provider OpenIDProvider + endpoints Endpoints +} + +// NewLegacyServer wraps provider in a `Server` and returns a handler which is +// the Server's router. +// +// Only non-nil endpoints will be registered on the router. +// Nil endpoints are disabled. +// +// The passed endpoints is also set to the provider, +// to be consistent with the discovery config. +// Any `With*Endpoint()` option used on the provider is +// therefore ineffective. +func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) http.Handler { + server := RegisterServer(&LegacyServer{ + provider: provider, + endpoints: endpoints, + }, endpoints, WithHTTPMiddleware(intercept(provider.IssuerFromRequest))) + + router := chi.NewRouter() + router.Mount("/", server) + router.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider)) + + return router +} + +func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) { + return NewResponse(Status{Status: "ok"}), nil +} + +func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) { + for _, probe := range s.provider.Probes() { + // shouldn't we run probes in Go routines? + if err := probe(ctx); err != nil { + return nil, NewStatusError(err, http.StatusInternalServerError) + } + } + return NewResponse(Status{Status: "ok"}), nil +} + +func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return NewResponse( + createDiscoveryConfigV2(ctx, s.provider, s.provider.Storage(), &s.endpoints), + ), nil +} + +func (s *LegacyServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) { + keys, err := s.provider.Storage().KeySet(ctx) + if err != nil { + return nil, NewStatusError(err, http.StatusInternalServerError) + } + return NewResponse(jsonWebKeySet(keys)), nil +} + +var ( + ErrAuthReqMissingClientID = errors.New("auth request is missing client_id") + ErrAuthReqMissingRedirectURI = errors.New("auth request is missing redirect_uri") +) + +func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) { + if r.Data.RequestParam != "" { + if !s.provider.RequestObjectSupported() { + return nil, oidc.ErrRequestNotSupported() + } + err := ParseRequestObject(ctx, r.Data, s.provider.Storage(), IssuerFromContext(ctx)) + if err != nil { + return nil, err + } + } + if r.Data.ClientID == "" { + return nil, ErrAuthReqMissingClientID + } + client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID) + if err != nil { + return nil, oidc.DefaultToServerError(err, "unable to retrieve client by id") + } + + return &ClientRequest[oidc.AuthRequest]{ + Request: r, + Client: client, + }, nil +} + +func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (_ *Redirect, err error) { + userID, err := ValidateAuthReqIDTokenHint(ctx, r.Data.IDTokenHint, s.provider.IDTokenHintVerifier(ctx)) + if err != nil { + return nil, err + } + req, err := s.provider.Storage().CreateAuthRequest(ctx, r.Data, userID) + if err != nil { + return TryErrorRedirect(ctx, r.Data, oidc.DefaultToServerError(err, "unable to save auth request"), s.provider.Encoder(), s.provider.Logger()) + } + return NewRedirect(r.Client.LoginURL(req.GetID())), nil +} + +func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) { + response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider) + if err != nil { + return nil, NewStatusError(err, http.StatusInternalServerError) + } + return NewResponse(response), nil +} + +func (s *LegacyServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) { + if oidc.GrantType(r.Form.Get("grant_type")) == oidc.GrantTypeClientCredentials { + storage, ok := s.provider.Storage().(ClientCredentialsStorage) + if !ok { + return nil, oidc.ErrUnsupportedGrantType().WithDescription("client_credentials grant not supported") + } + return storage.ClientCredentials(ctx, r.Data.ClientID, r.Data.ClientSecret) + } + + if r.Data.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion { + jwtExchanger, ok := s.provider.(JWTAuthorizationGrantExchanger) + if !ok || !s.provider.AuthMethodPrivateKeyJWTSupported() { + return nil, oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported") + } + return AuthorizePrivateJWTKey(ctx, r.Data.ClientAssertion, jwtExchanger) + } + client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID) + if err != nil { + return nil, oidc.ErrInvalidClient().WithParent(err) + } + + switch client.AuthMethod() { + case oidc.AuthMethodNone: + return client, nil + case oidc.AuthMethodPrivateKeyJWT: + return nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client") + case oidc.AuthMethodPost: + if !s.provider.AuthMethodPostSupported() { + return nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported") + } + } + + err = AuthorizeClientIDSecret(ctx, r.Data.ClientID, r.Data.ClientSecret, s.provider.Storage()) + if err != nil { + return nil, err + } + + return client, nil +} + +func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) { + authReq, err := AuthRequestByCode(ctx, s.provider.Storage(), r.Data.Code) + if err != nil { + return nil, err + } + if r.Client.AuthMethod() == oidc.AuthMethodNone { + if err = AuthorizeCodeChallenge(r.Data.CodeVerifier, authReq.GetCodeChallenge()); err != nil { + return nil, err + } + } + resp, err := CreateTokenResponse(ctx, authReq, r.Client, s.provider, true, r.Data.Code, "") + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) { + if !s.provider.GrantTypeRefreshTokenSupported() { + return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken) + } + request, err := RefreshTokenRequestByRefreshToken(ctx, s.provider.Storage(), r.Data.RefreshToken) + if err != nil { + return nil, err + } + if r.Client.GetID() != request.GetClientID() { + return nil, oidc.ErrInvalidGrant() + } + if err = ValidateRefreshTokenScopes(r.Data.Scopes, request); err != nil { + return nil, err + } + resp, err := CreateTokenResponse(ctx, request, r.Client, s.provider, true, "", r.Data.RefreshToken) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) { + exchanger, ok := s.provider.(JWTAuthorizationGrantExchanger) + if !ok { + return nil, unimplementedGrantError(oidc.GrantTypeBearer) + } + tokenRequest, err := VerifyJWTAssertion(ctx, r.Data.Assertion, exchanger.JWTProfileVerifier(ctx)) + if err != nil { + return nil, err + } + + tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(ctx, tokenRequest.Issuer, r.Data.Scope) + if err != nil { + return nil, err + } + resp, err := CreateJWTTokenResponse(ctx, tokenRequest, exchanger) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) { + if !s.provider.GrantTypeTokenExchangeSupported() { + return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange) + } + tokenExchangeRequest, err := CreateTokenExchangeRequest(ctx, r.Data, r.Client, s.provider) + if err != nil { + return nil, err + } + resp, err := CreateTokenExchangeResponse(ctx, tokenExchangeRequest, r.Client, s.provider) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) { + storage, ok := s.provider.Storage().(ClientCredentialsStorage) + if !ok { + return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials) + } + tokenRequest, err := storage.ClientCredentialsTokenRequest(ctx, r.Client.GetID(), r.Data.Scope) + if err != nil { + return nil, err + } + resp, err := CreateClientCredentialsTokenResponse(ctx, tokenRequest, s.provider, r.Client) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) { + if !s.provider.GrantTypeClientCredentialsSupported() { + return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode) + } + // use a limited context timeout shorter as the default + // poll interval of 5 seconds. + ctx, cancel := context.WithTimeout(ctx, 4*time.Second) + defer cancel() + + state, err := CheckDeviceAuthorizationState(ctx, r.Client.GetID(), r.Data.DeviceCode, s.provider) + if err != nil { + return nil, err + } + tokenRequest := &deviceAccessTokenRequest{ + subject: state.Subject, + audience: []string{r.Client.GetID()}, + scopes: state.Scopes, + } + resp, err := CreateDeviceTokenResponse(ctx, tokenRequest, s.provider, r.Client) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) { + response := new(oidc.IntrospectionResponse) + tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.Token) + if !ok { + return NewResponse(response), nil + } + err := s.provider.Storage().SetIntrospectionFromToken(ctx, response, tokenID, subject, r.Client.GetID()) + if err != nil { + return NewResponse(response), nil + } + response.Active = true + return NewResponse(response), nil +} + +func (s *LegacyServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) { + tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.AccessToken) + if !ok { + return nil, NewStatusError(oidc.ErrAccessDenied().WithDescription("access token invalid"), http.StatusUnauthorized) + } + info := new(oidc.UserInfo) + err := s.provider.Storage().SetUserinfoFromToken(ctx, info, tokenID, subject, r.Header.Get("origin")) + if err != nil { + return nil, NewStatusError(err, http.StatusForbidden) + } + return NewResponse(info), nil +} + +func (s *LegacyServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) { + var subject string + doDecrypt := true + if r.Data.TokenTypeHint != "access_token" { + userID, tokenID, err := s.provider.Storage().GetRefreshTokenInfo(ctx, r.Client.GetID(), r.Data.Token) + if err != nil { + // An invalid refresh token means that we'll try other things (leaving doDecrypt==true) + if !errors.Is(err, ErrInvalidRefreshToken) { + return nil, RevocationError(oidc.ErrServerError().WithParent(err)) + } + } else { + r.Data.Token = tokenID + subject = userID + doDecrypt = false + } + } + if doDecrypt { + tokenID, userID, ok := getTokenIDAndSubjectForRevocation(ctx, s.provider, r.Data.Token) + if ok { + r.Data.Token = tokenID + subject = userID + } + } + if err := s.provider.Storage().RevokeToken(ctx, r.Data.Token, subject, r.Client.GetID()); err != nil { + return nil, RevocationError(err) + } + return NewResponse(nil), nil +} + +func (s *LegacyServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) { + session, err := ValidateEndSessionRequest(ctx, r.Data, s.provider) + if err != nil { + return nil, err + } + err = s.provider.Storage().TerminateSession(ctx, session.UserID, session.ClientID) + if err != nil { + return nil, err + } + return NewRedirect(session.RedirectURI), nil +} diff --git a/pkg/op/server_test.go b/pkg/op/server_test.go new file mode 100644 index 00000000..0cad8fd5 --- /dev/null +++ b/pkg/op/server_test.go @@ -0,0 +1,5 @@ +package op + +// implementation check +var _ Server = &UnimplementedServer{} +var _ Server = &LegacyServer{} diff --git a/pkg/op/token_code.go b/pkg/op/token_code.go index baf377bc..371e1d41 100644 --- a/pkg/op/token_code.go +++ b/pkg/op/token_code.go @@ -88,7 +88,7 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, if err != nil { return nil, nil, err } - err = AuthorizeCodeChallenge(tokenReq, request.GetCodeChallenge()) + err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, request.GetCodeChallenge()) return request, client, err } if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() { diff --git a/pkg/op/token_exchange.go b/pkg/op/token_exchange.go index 21db1347..5156741d 100644 --- a/pkg/op/token_exchange.go +++ b/pkg/op/token_exchange.go @@ -197,12 +197,6 @@ func ValidateTokenExchangeRequest( return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing") } - storage := exchanger.Storage() - teStorage, ok := storage.(TokenExchangeStorage) - if !ok { - return nil, nil, oidc.ErrUnsupportedGrantType().WithDescription("token_exchange grant not supported") - } - client, err := AuthorizeTokenExchangeClient(ctx, clientID, clientSecret, exchanger) if err != nil { return nil, nil, err @@ -220,10 +214,28 @@ func ValidateTokenExchangeRequest( return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported") } + req, err := CreateTokenExchangeRequest(ctx, oidcTokenExchangeRequest, client, exchanger) + if err != nil { + return nil, nil, err + } + return req, client, nil +} + +func CreateTokenExchangeRequest( + ctx context.Context, + oidcTokenExchangeRequest *oidc.TokenExchangeRequest, + client Client, + exchanger Exchanger, +) (TokenExchangeRequest, error) { + teStorage, ok := exchanger.Storage().(TokenExchangeStorage) + if !ok { + return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange) + } + exchangeSubjectTokenIDOrToken, exchangeSubject, exchangeSubjectTokenClaims, ok := GetTokenIDAndSubjectFromToken(ctx, exchanger, oidcTokenExchangeRequest.SubjectToken, oidcTokenExchangeRequest.SubjectTokenType, false) if !ok { - return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token is invalid") + return nil, oidc.ErrInvalidRequest().WithDescription("subject_token is invalid") } var ( @@ -234,7 +246,7 @@ func ValidateTokenExchangeRequest( exchangeActorTokenIDOrToken, exchangeActor, exchangeActorTokenClaims, ok = GetTokenIDAndSubjectFromToken(ctx, exchanger, oidcTokenExchangeRequest.ActorToken, oidcTokenExchangeRequest.ActorTokenType, true) if !ok { - return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token is invalid") + return nil, oidc.ErrInvalidRequest().WithDescription("actor_token is invalid") } } @@ -258,17 +270,17 @@ func ValidateTokenExchangeRequest( authTime: time.Now(), } - err = teStorage.ValidateTokenExchangeRequest(ctx, req) + err := teStorage.ValidateTokenExchangeRequest(ctx, req) if err != nil { - return nil, nil, err + return nil, err } err = teStorage.CreateTokenExchangeRequest(ctx, req) if err != nil { - return nil, nil, err + return nil, err } - return req, client, nil + return req, nil } func GetTokenIDAndSubjectFromToken( diff --git a/pkg/op/token_request.go b/pkg/op/token_request.go index 0df2fcee..b810633c 100644 --- a/pkg/op/token_request.go +++ b/pkg/op/token_request.go @@ -117,11 +117,11 @@ func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, // AuthorizeCodeChallenge authorizes a client by validating the code_verifier against the previously sent // code_challenge of the auth request (PKCE) -func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, challenge *oidc.CodeChallenge) error { - if tokenReq.CodeVerifier == "" { +func AuthorizeCodeChallenge(codeVerifier string, challenge *oidc.CodeChallenge) error { + if codeVerifier == "" { return oidc.ErrInvalidRequest().WithDescription("code_challenge required") } - if !oidc.VerifyCodeChallenge(challenge, tokenReq.CodeVerifier) { + if !oidc.VerifyCodeChallenge(challenge, codeVerifier) { return oidc.ErrInvalidGrant().WithDescription("invalid code challenge") } return nil diff --git a/pkg/op/token_revocation.go b/pkg/op/token_revocation.go index fd1ee931..d19c7f7f 100644 --- a/pkg/op/token_revocation.go +++ b/pkg/op/token_revocation.go @@ -131,6 +131,11 @@ func ParseTokenRevocationRequest(r *http.Request, revoker Revoker) (token, token } func RevocationRequestError(w http.ResponseWriter, r *http.Request, err error) { + statusErr := RevocationError(err) + httphelper.MarshalJSONWithStatus(w, statusErr.parent, statusErr.statusCode) +} + +func RevocationError(err error) StatusError { e := oidc.DefaultToServerError(err, err.Error()) status := http.StatusBadRequest switch e.ErrorType { @@ -139,7 +144,7 @@ func RevocationRequestError(w http.ResponseWriter, r *http.Request, err error) { case oidc.ServerError: status = 500 } - httphelper.MarshalJSONWithStatus(w, e, status) + return NewStatusError(e, status) } func getTokenIDAndSubjectForRevocation(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) {