From 7236cca8cf6c95bd383a84c188c58e96e05c10f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=A2=D0=BE=D0=B2=D0=B0=D1=80=D0=B8=D1=89=20=D0=BF=D1=80?= =?UTF-8?q?=D0=BE=D0=B3=D1=80=D0=B0=D0=BC=D0=BC=D0=B8=D1=81=D1=82?= <46831212+ComradeProgrammer@users.noreply.github.com> Date: Mon, 11 Apr 2022 21:11:31 +0800 Subject: [PATCH] feat: implement CAS 3.0 (#659) --- controllers/cas.go | 83 ++++++++++++++++++++++++++--- object/saml_idp.go | 80 +++++++++++++++++++++++++++- object/token_cas.go | 109 +++++++++++++++++++++++++++++++++++---- routers/authz_filter.go | 2 +- routers/router.go | 8 ++- routers/static_filter.go | 2 +- 6 files changed, 262 insertions(+), 22 deletions(-) diff --git a/controllers/cas.go b/controllers/cas.go index 11bc744668a5..75595c568481 100644 --- a/controllers/cas.go +++ b/controllers/cas.go @@ -15,6 +15,7 @@ package controllers import ( + "encoding/xml" "fmt" "net/http" "net/url" @@ -42,7 +43,7 @@ func (c *RootController) CasValidate() { c.Ctx.Output.Body([]byte("no\n")) return } - if ok, response, issuedService := object.GetCasTokenByTicket(ticket); ok { + if ok, response, issuedService, _ := object.GetCasTokenByTicket(ticket); ok { //check whether service is the one for which we previously issued token if issuedService == service { c.Ctx.Output.Body([]byte(fmt.Sprintf("yes\n%s\n", response.User))) @@ -54,7 +55,25 @@ func (c *RootController) CasValidate() { c.Ctx.Output.Body([]byte("no\n")) } -func (c *RootController) CasServiceAndProxyValidate() { +func (c *RootController) CasServiceValidate() { + ticket := c.Input().Get("ticket") + format := c.Input().Get("format") + if !strings.HasPrefix(ticket, "ST") { + c.sendCasAuthenticationResponseErr(InvalidTicket, fmt.Sprintf("Ticket %s not recognized", ticket), format) + } + c.CasP3ServiceAndProxyValidate() +} + +func (c *RootController) CasProxyValidate() { + ticket := c.Input().Get("ticket") + format := c.Input().Get("format") + if !strings.HasPrefix(ticket, "PT") { + c.sendCasAuthenticationResponseErr(InvalidTicket, fmt.Sprintf("Ticket %s not recognized", ticket), format) + } + c.CasP3ServiceAndProxyValidate() +} + +func (c *RootController) CasP3ServiceAndProxyValidate() { ticket := c.Input().Get("ticket") format := c.Input().Get("format") service := c.Input().Get("service") @@ -69,10 +88,9 @@ func (c *RootController) CasServiceAndProxyValidate() { c.sendCasAuthenticationResponseErr(InvalidRequest, "service and ticket must exist", format) return } - + ok, response, issuedService, userId := object.GetCasTokenByTicket(ticket) //find the token - if ok, response, issuedService := object.GetCasTokenByTicket(ticket); ok { - + if ok { //check whether service is the one for which we previously issued token if strings.HasPrefix(service, issuedService) { serviceResponse.Success = response @@ -89,7 +107,7 @@ func (c *RootController) CasServiceAndProxyValidate() { if pgtUrl != "" && serviceResponse.Failure == nil { //that means we are in proxy web flow - pgt := object.StoreCasTokenForPgt(serviceResponse.Success, service) + pgt := object.StoreCasTokenForPgt(serviceResponse.Success, service, userId) pgtiou := serviceResponse.Success.ProxyGrantingTicket //todo: check whether it is https pgtUrlObj, err := url.Parse(pgtUrl) @@ -139,7 +157,7 @@ func (c *RootController) CasProxy() { return } - ok, authenticationSuccess, issuedService := object.GetCasTokenByPgt(pgt) + ok, authenticationSuccess, issuedService, userId := object.GetCasTokenByPgt(pgt) if !ok { c.sendCasProxyResponseErr(UnauthorizedService, "service not authorized", format) return @@ -150,7 +168,7 @@ func (c *RootController) CasProxy() { newAuthenticationSuccess.Proxies = &object.CasProxies{} } newAuthenticationSuccess.Proxies.Proxies = append(newAuthenticationSuccess.Proxies.Proxies, issuedService) - proxyTicket := object.StoreCasTokenForProxyTicket(&newAuthenticationSuccess, targetService) + proxyTicket := object.StoreCasTokenForProxyTicket(&newAuthenticationSuccess, targetService, userId) serviceResponse := object.CasServiceResponse{ Xmlns: "http://www.yale.edu/tp/cas", @@ -168,6 +186,55 @@ func (c *RootController) CasProxy() { } } + +func (c *RootController) SamlValidate() { + c.Ctx.Output.Header("Content-Type", "text/xml; charset=utf-8") + target := c.Input().Get("TARGET") + body := c.Ctx.Input.RequestBody + envelopRequest := struct { + XMLName xml.Name `xml:"Envelope"` + Body struct { + XMLName xml.Name `xml:"Body"` + Content string `xml:",innerxml"` + } + }{} + + err := xml.Unmarshal(body, &envelopRequest) + if err != nil { + c.ResponseError(err.Error()) + return + } + + response, service, err := object.GetValidationBySaml(envelopRequest.Body.Content, c.Ctx.Request.Host) + if err != nil { + c.ResponseError(err.Error()) + return + } + + if !strings.HasPrefix(target, service) { + c.ResponseError(fmt.Sprintf("service %s and %s do not match", target, service)) + return + } + + envelopReponse := struct { + XMLName xml.Name `xml:"SOAP-ENV:Envelope"` + Xmlns string `xml:"xmlns:SOAP-ENV"` + Body struct { + XMLName xml.Name `xml:"SOAP-ENV:Body"` + Content string `xml:",innerxml"` + } + }{} + envelopReponse.Xmlns = "http://schemas.xmlsoap.org/soap/envelope/" + envelopReponse.Body.Content = response + + data, err := xml.Marshal(envelopReponse) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.Ctx.Output.Body([]byte(data)) +} + func (c *RootController) sendCasProxyResponseErr(code, msg, format string) { serviceResponse := object.CasServiceResponse{ Xmlns: "http://www.yale.edu/tp/cas", diff --git a/object/saml_idp.go b/object/saml_idp.go index 01b4182ef43e..9b03a408f5d6 100644 --- a/object/saml_idp.go +++ b/object/saml_idp.go @@ -20,6 +20,7 @@ import ( "crypto" "crypto/rsa" "encoding/base64" + "encoding/json" "encoding/pem" "encoding/xml" "fmt" @@ -34,6 +35,7 @@ import ( uuid "github.com/satori/go.uuid" ) +//returns a saml2 response func NewSamlResponse(user *User, host string, publicKey string, destination string, iss string, redirectUri []string) (*etree.Element, error) { samlResponse := &etree.Element{ Space: "samlp", @@ -223,7 +225,8 @@ func GetSamlMeta(application *Application, host string) (*IdpEntityDescriptor, e return &d, nil } -//GenerateSamlResponse generates a SAML response +//GenerateSamlResponse generates a SAML2.0 response +//parameter samlRequest is saml request in base64 format func GetSamlResponse(application *Application, user *User, samlRequest string, host string) (string, string, error) { //decode samlRequest defated, err := base64.StdEncoding.DecodeString(samlRequest) @@ -272,3 +275,78 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h res := base64.StdEncoding.EncodeToString([]byte(xmlStr)) return res, authnRequest.AssertionConsumerServiceURL, nil } + +//return a saml1.1 response(not 2.0) +func NewSamlResponse11(user *User, requestID string, host string) *etree.Element { + samlResponse := &etree.Element{ + Space: "samlp", + Tag: "Response", + } + //create samlresponse + samlResponse.CreateAttr("xmlns:samlp", "urn:oasis:names:tc:SAML:1.0:protocol") + samlResponse.CreateAttr("MajorVersion", "1") + samlResponse.CreateAttr("MinorVersion", "1") + + responseID := uuid.NewV4() + samlResponse.CreateAttr("ResponseID", fmt.Sprintf("_%s", responseID)) + samlResponse.CreateAttr("InResponseTo", requestID) + + now := time.Now().UTC().Format(time.RFC3339) + expireTime := time.Now().UTC().Add(time.Hour * 24).Format(time.RFC3339) + + samlResponse.CreateAttr("IssueInstant", now) + + samlResponse.CreateElement("samlp:Status").CreateElement("samlp:StatusCode").CreateAttr("Value", "samlp:Success") + + //create assertion which is inside the response + assertion := samlResponse.CreateElement("saml:Assertion") + assertion.CreateAttr("xmlns:saml", "urn:oasis:names:tc:SAML:1.0:assertion") + assertion.CreateAttr("MajorVersion", "1") + assertion.CreateAttr("MinorVersion", "1") + assertion.CreateAttr("AssertionID", uuid.NewV4().String()) + assertion.CreateAttr("Issuer", host) + assertion.CreateAttr("IssueInstant", now) + + condition := assertion.CreateElement("saml:Conditions") + condition.CreateAttr("NotBefore", now) + condition.CreateAttr("NotOnOrAfter", expireTime) + + //AuthenticationStatement inside assertion + authenticationStatement := assertion.CreateElement("saml:AuthenticationStatement") + authenticationStatement.CreateAttr("AuthenticationMethod", "urn:oasis:names:tc:SAML:1.0:am:password") + authenticationStatement.CreateAttr("AuthenticationInstant", now) + + //subject inside AuthenticationStatement + subject := assertion.CreateElement("saml:Subject") + //nameIdentifier inside subject + nameIdentifier := subject.CreateElement("saml:NameIdentifier") + //nameIdentifier.CreateAttr("Format", "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress") + nameIdentifier.SetText(user.Name) + + //subjectConfirmation inside subject + subjectConfirmation := subject.CreateElement("saml:SubjectConfirmation") + subjectConfirmation.CreateElement("saml:ConfirmationMethod").SetText("urn:oasis:names:tc:SAML:1.0:cm:artifact") + + attributeStatement := assertion.CreateElement("saml:AttributeStatement") + subjectInAttribute := attributeStatement.CreateElement("saml:Subject") + nameIdentifierInAttribute := subjectInAttribute.CreateElement("saml:NameIdentifier") + nameIdentifierInAttribute.SetText(user.Name) + + subjectConfirmationInAttribute := subjectInAttribute.CreateElement("saml:SubjectConfirmation") + subjectConfirmationInAttribute.CreateElement("saml:ConfirmationMethod").SetText("urn:oasis:names:tc:SAML:1.0:cm:artifact") + + data, _ := json.Marshal(user) + tmp := map[string]string{} + json.Unmarshal(data, &tmp) + + for k, v := range tmp { + if v != "" { + attr := attributeStatement.CreateElement("saml:Attribute") + attr.CreateAttr("saml:AttributeName", k) + attr.CreateAttr("saml:AttributeNamespace", "http://www.ja-sig.org/products/cas/") + attr.CreateElement("saml:AttributeValue").SetText(v) + } + } + + return samlResponse +} diff --git a/object/token_cas.go b/object/token_cas.go index 5e9fdca5674f..048fcf14aafb 100644 --- a/object/token_cas.go +++ b/object/token_cas.go @@ -15,14 +15,19 @@ package object import ( + "crypto" + "encoding/base64" "encoding/json" + "encoding/pem" "encoding/xml" "fmt" "math/rand" "sync" "time" + "github.com/beevik/etree" "github.com/casdoor/casdoor/util" + dsig "github.com/russellhaering/goxmldsig" ) type CasServiceResponse struct { @@ -84,6 +89,7 @@ type CasAnyAttribute struct { type CasAuthenticationSuccessWrapper struct { AuthenticationSuccess *CasAuthenticationSuccess // the token we issued Service string //to which service this token is issued + UserId string } type CasProxySuccess struct { @@ -96,17 +102,32 @@ type CasProxyFailure struct { Message string `xml:",innerxml"` } +type Saml11Request struct { + XMLName xml.Name `xml:"Request"` + SAMLP string `xml:"samlp,attr"` + MajorVersion string `xml:"MajorVersion,attr"` + MinorVersion string `xml:"MinorVersion,attr"` + RequestID string `xml:"RequestID,attr"` + IssueInstant string `xml:"IssueInstance,attr"` + AssertionArtifact Saml11AssertionArtifact +} +type Saml11AssertionArtifact struct { + XMLName xml.Name `xml:"AssertionArtifact"` + InnerXML string `xml:",innerxml"` +} + //st is short for service ticket var stToServiceResponse sync.Map //pgt is short for proxy granting ticket var pgtToServiceResponse sync.Map -func StoreCasTokenForPgt(token *CasAuthenticationSuccess, service string) string { +func StoreCasTokenForPgt(token *CasAuthenticationSuccess, service, userId string) string { pgt := fmt.Sprintf("PGT-%s", util.GenerateId()) pgtToServiceResponse.Store(pgt, &CasAuthenticationSuccessWrapper{ AuthenticationSuccess: token, Service: service, + UserId: userId, }) return pgt } @@ -115,33 +136,45 @@ func GenerateId() { panic("unimplemented") } -func GetCasTokenByPgt(pgt string) (bool, *CasAuthenticationSuccess, string) { +/** +@ret1: whether a token is found +@ret2: token, nil if not found +@ret3: the service URL who requested to issue this token +@ret4: userIf of user who requested to issue this token +*/ +func GetCasTokenByPgt(pgt string) (bool, *CasAuthenticationSuccess, string, string) { if responseWrapperType, ok := pgtToServiceResponse.LoadAndDelete(pgt); ok { responseWrapperTypeCast := responseWrapperType.(*CasAuthenticationSuccessWrapper) - return true, responseWrapperTypeCast.AuthenticationSuccess, responseWrapperTypeCast.Service + return true, responseWrapperTypeCast.AuthenticationSuccess, responseWrapperTypeCast.Service, responseWrapperTypeCast.UserId } - return false, nil, "" + return false, nil, "", "" } -func GetCasTokenByTicket(ticket string) (bool, *CasAuthenticationSuccess, string) { +/** +@ret1: whether a token is found +@ret2: token, nil if not found +@ret3: the service URL who requested to issue this token +@ret4: userIf of user who requested to issue this token +*/ +func GetCasTokenByTicket(ticket string) (bool, *CasAuthenticationSuccess, string, string) { if responseWrapperType, ok := stToServiceResponse.LoadAndDelete(ticket); ok { responseWrapperTypeCast := responseWrapperType.(*CasAuthenticationSuccessWrapper) - return true, responseWrapperTypeCast.AuthenticationSuccess, responseWrapperTypeCast.Service + return true, responseWrapperTypeCast.AuthenticationSuccess, responseWrapperTypeCast.Service, responseWrapperTypeCast.UserId } - return false, nil, "" + return false, nil, "", "" } -func StoreCasTokenForProxyTicket(token *CasAuthenticationSuccess, targetService string) string { +func StoreCasTokenForProxyTicket(token *CasAuthenticationSuccess, targetService, userId string) string { proxyTicket := fmt.Sprintf("PT-%s", util.GenerateId()) stToServiceResponse.Store(proxyTicket, &CasAuthenticationSuccessWrapper{ AuthenticationSuccess: token, Service: targetService, + UserId: userId, }) return proxyTicket } func GenerateCasToken(userId string, service string) (string, error) { - if user := GetUser(userId); user != nil { authenticationSuccess := CasAuthenticationSuccess{ User: user.Name, @@ -166,11 +199,69 @@ func GenerateCasToken(userId string, service string) (string, error) { stToServiceResponse.Store(st, &CasAuthenticationSuccessWrapper{ AuthenticationSuccess: &authenticationSuccess, Service: service, + UserId: userId, }) return st, nil } else { return "", fmt.Errorf("invalid user Id") } +} + +/** +@ret1: saml response +@ret2: the service URL who requested to issue this token +@ret3: error +*/ +func GetValidationBySaml(samlRequest string, host string) (string, string, error) { + var request Saml11Request + err := xml.Unmarshal([]byte(samlRequest), &request) + if err != nil { + return "", "", err + } + + ticket := request.AssertionArtifact.InnerXML + if ticket == "" { + return "", "", fmt.Errorf("samlp:AssertionArtifact field not found") + } + + ok, _, service, userId := GetCasTokenByTicket(ticket) + if !ok { + return "", "", fmt.Errorf("ticket %s found", ticket) + } + + user := GetUser(userId) + if user == nil { + return "", "", fmt.Errorf("user %s found", userId) + } + application := GetApplicationByUser(user) + if application == nil { + return "", "", fmt.Errorf("application for user %s found", userId) + } + + samlResponse := NewSamlResponse11(user, request.RequestID, host) + + cert := getCertByApplication(application) + block, _ := pem.Decode([]byte(cert.PublicKey)) + publicKey := base64.StdEncoding.EncodeToString(block.Bytes) + randomKeyStore := &X509Key{ + PrivateKey: cert.PrivateKey, + X509Certificate: publicKey, + } + + ctx := dsig.NewDefaultSigningContext(randomKeyStore) + ctx.Hash = crypto.SHA1 + signedXML, err := ctx.SignEnveloped(samlResponse) + if err != nil { + return "", "", fmt.Errorf("err: %s", err.Error()) + } + + doc := etree.NewDocument() + doc.SetRoot(signedXML) + xmlStr, err := doc.WriteToString() + if err != nil { + return "", "", fmt.Errorf("err: %s", err.Error()) + } + return xmlStr, service, nil } diff --git a/routers/authz_filter.go b/routers/authz_filter.go index 8db5f861c1e5..e8c75750d470 100644 --- a/routers/authz_filter.go +++ b/routers/authz_filter.go @@ -101,7 +101,7 @@ func willLog(subOwner string, subName string, method string, urlPath string, obj } func getUrlPath(urlPath string) string { - if strings.HasPrefix(urlPath, "/cas") && (strings.HasSuffix(urlPath, "/serviceValidate") || strings.HasSuffix(urlPath, "/proxy") || strings.HasSuffix(urlPath, "/proxyValidate") || strings.HasSuffix(urlPath, "/validate")) { + if strings.HasPrefix(urlPath, "/cas") && (strings.HasSuffix(urlPath, "/serviceValidate") || strings.HasSuffix(urlPath, "/proxy") || strings.HasSuffix(urlPath, "/proxyValidate") || strings.HasSuffix(urlPath, "/validate") || strings.HasSuffix(urlPath, "/p3/serviceValidate") || strings.HasSuffix(urlPath, "/p3/proxyValidate") || strings.HasSuffix(urlPath, "/samlValidate")) { return "/cas" } return urlPath diff --git a/routers/router.go b/routers/router.go index 184bafeef553..1b831d6dd4c4 100644 --- a/routers/router.go +++ b/routers/router.go @@ -173,9 +173,13 @@ func initAPI() { beego.Router("/.well-known/openid-configuration", &controllers.RootController{}, "GET:GetOidcDiscovery") beego.Router("/.well-known/jwks", &controllers.RootController{}, "*:GetJwks") - beego.Router("/cas/:organization/:application/serviceValidate", &controllers.RootController{}, "GET:CasServiceAndProxyValidate") - beego.Router("/cas/:organization/:application/proxyValidate", &controllers.RootController{}, "GET:CasServiceAndProxyValidate") + beego.Router("/cas/:organization/:application/serviceValidate", &controllers.RootController{}, "GET:CasServiceValidate") + beego.Router("/cas/:organization/:application/proxyValidate", &controllers.RootController{}, "GET:CasProxyValidate") beego.Router("/cas/:organization/:application/proxy", &controllers.RootController{}, "GET:CasProxy") beego.Router("/cas/:organization/:application/validate", &controllers.RootController{}, "GET:CasValidate") + beego.Router("/cas/:organization/:application/p3/serviceValidate", &controllers.RootController{}, "GET:CasP3ServiceAndProxyValidate") + beego.Router("/cas/:organization/:application/p3/proxyValidate", &controllers.RootController{}, "GET:CasP3ServiceAndProxyValidate") + beego.Router("/cas/:organization/:application/samlValidate", &controllers.RootController{}, "POST:SamlValidate") + } diff --git a/routers/static_filter.go b/routers/static_filter.go index b8f2ccc4ee7d..2e7bc094f26e 100644 --- a/routers/static_filter.go +++ b/routers/static_filter.go @@ -27,7 +27,7 @@ func StaticFilter(ctx *context.Context) { if strings.HasPrefix(urlPath, "/api/") || strings.HasPrefix(urlPath, "/.well-known/") { return } - if strings.HasPrefix(urlPath, "/cas") && (strings.HasSuffix(urlPath, "/serviceValidate") || strings.HasSuffix(urlPath, "/proxy") || strings.HasSuffix(urlPath, "/proxyValidate") || strings.HasSuffix(urlPath, "/validate")) { + if strings.HasPrefix(urlPath, "/cas") && (strings.HasSuffix(urlPath, "/serviceValidate") || strings.HasSuffix(urlPath, "/proxy") || strings.HasSuffix(urlPath, "/proxyValidate") || strings.HasSuffix(urlPath, "/validate") || strings.HasSuffix(urlPath, "/p3/serviceValidate") || strings.HasSuffix(urlPath, "/p3/proxyValidate") || strings.HasSuffix(urlPath, "/samlValidate")) { return }