Skip to content

Commit

Permalink
feat: implement CAS 3.0 (casdoor#659)
Browse files Browse the repository at this point in the history
  • Loading branch information
ComradeProgrammer authored Apr 11, 2022
1 parent 15daf5d commit 7236cca
Show file tree
Hide file tree
Showing 6 changed files with 262 additions and 22 deletions.
83 changes: 75 additions & 8 deletions controllers/cas.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package controllers

import (
"encoding/xml"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -42,7 +43,7 @@ func (c *RootController) CasValidate() {
c.Ctx.Output.Body([]byte("no\n"))
return
}
if ok, response, issuedService := object.GetCasTokenByTicket(ticket); ok {
if ok, response, issuedService, _ := object.GetCasTokenByTicket(ticket); ok {
//check whether service is the one for which we previously issued token
if issuedService == service {
c.Ctx.Output.Body([]byte(fmt.Sprintf("yes\n%s\n", response.User)))
Expand All @@ -54,7 +55,25 @@ func (c *RootController) CasValidate() {
c.Ctx.Output.Body([]byte("no\n"))
}

func (c *RootController) CasServiceAndProxyValidate() {
func (c *RootController) CasServiceValidate() {
ticket := c.Input().Get("ticket")
format := c.Input().Get("format")
if !strings.HasPrefix(ticket, "ST") {
c.sendCasAuthenticationResponseErr(InvalidTicket, fmt.Sprintf("Ticket %s not recognized", ticket), format)
}
c.CasP3ServiceAndProxyValidate()
}

func (c *RootController) CasProxyValidate() {
ticket := c.Input().Get("ticket")
format := c.Input().Get("format")
if !strings.HasPrefix(ticket, "PT") {
c.sendCasAuthenticationResponseErr(InvalidTicket, fmt.Sprintf("Ticket %s not recognized", ticket), format)
}
c.CasP3ServiceAndProxyValidate()
}

func (c *RootController) CasP3ServiceAndProxyValidate() {
ticket := c.Input().Get("ticket")
format := c.Input().Get("format")
service := c.Input().Get("service")
Expand All @@ -69,10 +88,9 @@ func (c *RootController) CasServiceAndProxyValidate() {
c.sendCasAuthenticationResponseErr(InvalidRequest, "service and ticket must exist", format)
return
}

ok, response, issuedService, userId := object.GetCasTokenByTicket(ticket)
//find the token
if ok, response, issuedService := object.GetCasTokenByTicket(ticket); ok {

if ok {
//check whether service is the one for which we previously issued token
if strings.HasPrefix(service, issuedService) {
serviceResponse.Success = response
Expand All @@ -89,7 +107,7 @@ func (c *RootController) CasServiceAndProxyValidate() {

if pgtUrl != "" && serviceResponse.Failure == nil {
//that means we are in proxy web flow
pgt := object.StoreCasTokenForPgt(serviceResponse.Success, service)
pgt := object.StoreCasTokenForPgt(serviceResponse.Success, service, userId)
pgtiou := serviceResponse.Success.ProxyGrantingTicket
//todo: check whether it is https
pgtUrlObj, err := url.Parse(pgtUrl)
Expand Down Expand Up @@ -139,7 +157,7 @@ func (c *RootController) CasProxy() {
return
}

ok, authenticationSuccess, issuedService := object.GetCasTokenByPgt(pgt)
ok, authenticationSuccess, issuedService, userId := object.GetCasTokenByPgt(pgt)
if !ok {
c.sendCasProxyResponseErr(UnauthorizedService, "service not authorized", format)
return
Expand All @@ -150,7 +168,7 @@ func (c *RootController) CasProxy() {
newAuthenticationSuccess.Proxies = &object.CasProxies{}
}
newAuthenticationSuccess.Proxies.Proxies = append(newAuthenticationSuccess.Proxies.Proxies, issuedService)
proxyTicket := object.StoreCasTokenForProxyTicket(&newAuthenticationSuccess, targetService)
proxyTicket := object.StoreCasTokenForProxyTicket(&newAuthenticationSuccess, targetService, userId)

serviceResponse := object.CasServiceResponse{
Xmlns: "http://www.yale.edu/tp/cas",
Expand All @@ -168,6 +186,55 @@ func (c *RootController) CasProxy() {
}

}

func (c *RootController) SamlValidate() {
c.Ctx.Output.Header("Content-Type", "text/xml; charset=utf-8")
target := c.Input().Get("TARGET")
body := c.Ctx.Input.RequestBody
envelopRequest := struct {
XMLName xml.Name `xml:"Envelope"`
Body struct {
XMLName xml.Name `xml:"Body"`
Content string `xml:",innerxml"`
}
}{}

err := xml.Unmarshal(body, &envelopRequest)
if err != nil {
c.ResponseError(err.Error())
return
}

response, service, err := object.GetValidationBySaml(envelopRequest.Body.Content, c.Ctx.Request.Host)
if err != nil {
c.ResponseError(err.Error())
return
}

if !strings.HasPrefix(target, service) {
c.ResponseError(fmt.Sprintf("service %s and %s do not match", target, service))
return
}

envelopReponse := struct {
XMLName xml.Name `xml:"SOAP-ENV:Envelope"`
Xmlns string `xml:"xmlns:SOAP-ENV"`
Body struct {
XMLName xml.Name `xml:"SOAP-ENV:Body"`
Content string `xml:",innerxml"`
}
}{}
envelopReponse.Xmlns = "http://schemas.xmlsoap.org/soap/envelope/"
envelopReponse.Body.Content = response

data, err := xml.Marshal(envelopReponse)
if err != nil {
c.ResponseError(err.Error())
return
}
c.Ctx.Output.Body([]byte(data))
}

func (c *RootController) sendCasProxyResponseErr(code, msg, format string) {
serviceResponse := object.CasServiceResponse{
Xmlns: "http://www.yale.edu/tp/cas",
Expand Down
80 changes: 79 additions & 1 deletion object/saml_idp.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"crypto"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"encoding/pem"
"encoding/xml"
"fmt"
Expand All @@ -34,6 +35,7 @@ import (
uuid "github.com/satori/go.uuid"
)

//returns a saml2 response
func NewSamlResponse(user *User, host string, publicKey string, destination string, iss string, redirectUri []string) (*etree.Element, error) {
samlResponse := &etree.Element{
Space: "samlp",
Expand Down Expand Up @@ -223,7 +225,8 @@ func GetSamlMeta(application *Application, host string) (*IdpEntityDescriptor, e
return &d, nil
}

//GenerateSamlResponse generates a SAML response
//GenerateSamlResponse generates a SAML2.0 response
//parameter samlRequest is saml request in base64 format
func GetSamlResponse(application *Application, user *User, samlRequest string, host string) (string, string, error) {
//decode samlRequest
defated, err := base64.StdEncoding.DecodeString(samlRequest)
Expand Down Expand Up @@ -272,3 +275,78 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h
res := base64.StdEncoding.EncodeToString([]byte(xmlStr))
return res, authnRequest.AssertionConsumerServiceURL, nil
}

//return a saml1.1 response(not 2.0)
func NewSamlResponse11(user *User, requestID string, host string) *etree.Element {
samlResponse := &etree.Element{
Space: "samlp",
Tag: "Response",
}
//create samlresponse
samlResponse.CreateAttr("xmlns:samlp", "urn:oasis:names:tc:SAML:1.0:protocol")
samlResponse.CreateAttr("MajorVersion", "1")
samlResponse.CreateAttr("MinorVersion", "1")

responseID := uuid.NewV4()
samlResponse.CreateAttr("ResponseID", fmt.Sprintf("_%s", responseID))
samlResponse.CreateAttr("InResponseTo", requestID)

now := time.Now().UTC().Format(time.RFC3339)
expireTime := time.Now().UTC().Add(time.Hour * 24).Format(time.RFC3339)

samlResponse.CreateAttr("IssueInstant", now)

samlResponse.CreateElement("samlp:Status").CreateElement("samlp:StatusCode").CreateAttr("Value", "samlp:Success")

//create assertion which is inside the response
assertion := samlResponse.CreateElement("saml:Assertion")
assertion.CreateAttr("xmlns:saml", "urn:oasis:names:tc:SAML:1.0:assertion")
assertion.CreateAttr("MajorVersion", "1")
assertion.CreateAttr("MinorVersion", "1")
assertion.CreateAttr("AssertionID", uuid.NewV4().String())
assertion.CreateAttr("Issuer", host)
assertion.CreateAttr("IssueInstant", now)

condition := assertion.CreateElement("saml:Conditions")
condition.CreateAttr("NotBefore", now)
condition.CreateAttr("NotOnOrAfter", expireTime)

//AuthenticationStatement inside assertion
authenticationStatement := assertion.CreateElement("saml:AuthenticationStatement")
authenticationStatement.CreateAttr("AuthenticationMethod", "urn:oasis:names:tc:SAML:1.0:am:password")
authenticationStatement.CreateAttr("AuthenticationInstant", now)

//subject inside AuthenticationStatement
subject := assertion.CreateElement("saml:Subject")
//nameIdentifier inside subject
nameIdentifier := subject.CreateElement("saml:NameIdentifier")
//nameIdentifier.CreateAttr("Format", "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress")
nameIdentifier.SetText(user.Name)

//subjectConfirmation inside subject
subjectConfirmation := subject.CreateElement("saml:SubjectConfirmation")
subjectConfirmation.CreateElement("saml:ConfirmationMethod").SetText("urn:oasis:names:tc:SAML:1.0:cm:artifact")

attributeStatement := assertion.CreateElement("saml:AttributeStatement")
subjectInAttribute := attributeStatement.CreateElement("saml:Subject")
nameIdentifierInAttribute := subjectInAttribute.CreateElement("saml:NameIdentifier")
nameIdentifierInAttribute.SetText(user.Name)

subjectConfirmationInAttribute := subjectInAttribute.CreateElement("saml:SubjectConfirmation")
subjectConfirmationInAttribute.CreateElement("saml:ConfirmationMethod").SetText("urn:oasis:names:tc:SAML:1.0:cm:artifact")

data, _ := json.Marshal(user)
tmp := map[string]string{}
json.Unmarshal(data, &tmp)

for k, v := range tmp {
if v != "" {
attr := attributeStatement.CreateElement("saml:Attribute")
attr.CreateAttr("saml:AttributeName", k)
attr.CreateAttr("saml:AttributeNamespace", "http://www.ja-sig.org/products/cas/")
attr.CreateElement("saml:AttributeValue").SetText(v)
}
}

return samlResponse
}
Loading

0 comments on commit 7236cca

Please sign in to comment.