diff --git a/auth.go b/auth.go index 51d122af3..f624cc098 100644 --- a/auth.go +++ b/auth.go @@ -35,13 +35,12 @@ import ( "sync" ) -type authInfo struct { - db, user, pass string -} - type authCmd struct { - Authenticate int - Nonce, User, Key string + Authenticate int + + Nonce string + User string + Key string } type startSaslCmd struct { @@ -67,6 +66,29 @@ type logoutCmd struct { Logout int } +type saslCmd struct { + Start int `bson:"saslStart,omitempty"` + Continue int `bson:"saslContinue,omitempty"` + ConversationId int `bson:"conversationId,omitempty"` + Mechanism string `bson:"mechanism,omitempty"` + Payload []byte +} + +type saslResult struct { + Ok bool `bson:"ok"` + NotOk bool `bson:"code"` // Server <= 2.3.2 returns ok=1 & code>0 on errors (WTF?) + Done bool + + ConversationId int `bson:"conversationId"` + Payload []byte + ErrMsg string +} + +type saslStepper interface { + Step(serverData []byte) (clientData []byte, done bool, err error) + Close() +} + func (socket *mongoSocket) getNonce() (nonce string, err error) { socket.Lock() for socket.cachedNonce == "" && socket.dead == nil { @@ -133,25 +155,35 @@ func (socket *mongoSocket) resetNonce() { } } -func (socket *mongoSocket) Login(auth authInfo) error { +func (socket *mongoSocket) Login(cred Credential) error { socket.Lock() - for _, a := range socket.auth { - if a == auth { - debugf("Socket %p to %s: login: db=%q user=%q (already logged in)", socket, socket.addr, auth.db, auth.user) + for _, sockCred := range socket.creds { + if sockCred == cred { + debugf("Socket %p to %s: login: db=%q user=%q (already logged in)", socket, socket.addr, cred.Source, cred.Username) socket.Unlock() return nil } } - if socket.dropLogout(auth) { - debugf("Socket %p to %s: login: db=%q user=%q (cached)", socket, socket.addr, auth.db, auth.user) - socket.auth = append(socket.auth, auth) + if socket.dropLogout(cred) { + debugf("Socket %p to %s: login: db=%q user=%q (cached)", socket, socket.addr, cred.Source, cred.Username) + socket.creds = append(socket.creds, cred) socket.Unlock() return nil } socket.Unlock() - debugf("Socket %p to %s: login: db=%q user=%q", socket, socket.addr, auth.db, auth.user) - err := socket.loginClassic(auth) + debugf("Socket %p to %s: login: db=%q user=%q", socket, socket.addr, cred.Source, cred.Username) + + var err error + switch cred.Mechanism { + case "", "MONGO-CR": + err = socket.loginClassic(cred) + case "MONGO-X509": + err = fmt.Errorf("unsupported authentication mechanism: %s", cred.Mechanism) + default: + // Try SASL for everything else, if it is available. + err = socket.loginSASL(cred) + } if err != nil { debugf("Socket %p to %s: login error: %s", socket, socket.addr, err) @@ -161,7 +193,7 @@ func (socket *mongoSocket) Login(auth authInfo) error { return err } -func (socket *mongoSocket) loginClassic(auth authInfo) error { +func (socket *mongoSocket) loginClassic(cred Credential) error { // Note that this only works properly because this function is // synchronous, which means the nonce won't get reset while we're // using it and any other login requests will block waiting for a @@ -173,28 +205,101 @@ func (socket *mongoSocket) loginClassic(auth authInfo) error { defer socket.resetNonce() psum := md5.New() - psum.Write([]byte(auth.user + ":mongo:" + auth.pass)) + psum.Write([]byte(cred.Username + ":mongo:" + cred.Password)) ksum := md5.New() - ksum.Write([]byte(nonce + auth.user)) + ksum.Write([]byte(nonce + cred.Username)) ksum.Write([]byte(hex.EncodeToString(psum.Sum(nil)))) key := hex.EncodeToString(ksum.Sum(nil)) - cmd := authCmd{Authenticate: 1, User: auth.user, Nonce: nonce, Key: key} + cmd := authCmd{Authenticate: 1, User: cred.Username, Nonce: nonce, Key: key} res := authResult{} - return socket.loginRun(auth.db, &cmd, &res, func() error { + return socket.loginRun(cred.Source, &cmd, &res, func() error { if !res.Ok { return errors.New(res.ErrMsg) } socket.Lock() - socket.dropAuth(auth.db) - socket.auth = append(socket.auth, auth) + socket.dropAuth(cred.Source) + socket.creds = append(socket.creds, cred) socket.Unlock() return nil }) } +func (socket *mongoSocket) loginSASL(cred Credential) error { + sasl, err := saslNew(cred, socket.Server().Addr) + if err != nil { + return err + } + defer sasl.Close() + + // The goal of this logic is to carry a locked socket until the + // local SASL step confirms the auth is valid; the socket needs to be + // locked so that concurrent action doesn't leave the socket in an + // auth state that doesn't reflect the operations that took place. + // As a simple case, imagine inverting login=>logout to logout=>login. + // + // The logic below works because the lock func isn't called concurrently. + locked := false + lock := func(b bool) { + if locked != b { + locked = b + if b { + socket.Lock() + } else { + socket.Unlock() + } + } + } + + lock(true) + defer lock(false) + + start := 1 + cmd := saslCmd{} + res := saslResult{} + for { + payload, done, err := sasl.Step(res.Payload) + if err != nil { + return err + } + if done && res.Done { + socket.dropAuth(cred.Source) + socket.creds = append(socket.creds, cred) + break + } + lock(false) + + cmd = saslCmd{ + Start: start, + Continue: 1 - start, + ConversationId: res.ConversationId, + Mechanism: cred.Mechanism, + Payload: payload, + } + start = 0 + err = socket.loginRun(cred.Source, &cmd, &res, func() error { + // See the comment on lock for why this is necessary. + lock(true) + if !res.Ok || res.NotOk { + return fmt.Errorf("server returned error on SASL authentication step: %s", res.ErrMsg) + } + return nil + }) + if err != nil { + return err + } + if done && res.Done { + socket.dropAuth(cred.Source) + socket.creds = append(socket.creds, cred) + break + } + } + + return nil +} + func (socket *mongoSocket) loginRun(db string, query, result interface{}, f func() error) error { var mutex sync.Mutex var replyErr error @@ -232,20 +337,20 @@ func (socket *mongoSocket) loginRun(db string, query, result interface{}, f func func (socket *mongoSocket) Logout(db string) { socket.Lock() - auth, found := socket.dropAuth(db) + cred, found := socket.dropAuth(db) if found { debugf("Socket %p to %s: logout: db=%q (flagged)", socket, socket.addr, db) - socket.logout = append(socket.logout, auth) + socket.logout = append(socket.logout, cred) } socket.Unlock() } func (socket *mongoSocket) LogoutAll() { socket.Lock() - if l := len(socket.auth); l > 0 { + if l := len(socket.creds); l > 0 { debugf("Socket %p to %s: logout all (flagged %d)", socket, socket.addr, l) - socket.logout = append(socket.logout, socket.auth...) - socket.auth = socket.auth[0:0] + socket.logout = append(socket.logout, socket.creds...) + socket.creds = socket.creds[0:0] } socket.Unlock() } @@ -257,7 +362,7 @@ func (socket *mongoSocket) flushLogout() (ops []interface{}) { for i := 0; i != l; i++ { op := queryOp{} op.query = &logoutCmd{1} - op.collection = socket.logout[i].db + ".$cmd" + op.collection = socket.logout[i].Source + ".$cmd" op.limit = -1 ops = append(ops, &op) } @@ -267,20 +372,20 @@ func (socket *mongoSocket) flushLogout() (ops []interface{}) { return } -func (socket *mongoSocket) dropAuth(db string) (auth authInfo, found bool) { - for i, a := range socket.auth { - if a.db == db { - copy(socket.auth[i:], socket.auth[i+1:]) - socket.auth = socket.auth[:len(socket.auth)-1] - return a, true +func (socket *mongoSocket) dropAuth(db string) (cred Credential, found bool) { + for i, sockCred := range socket.creds { + if sockCred.Source == db { + copy(socket.creds[i:], socket.creds[i+1:]) + socket.creds = socket.creds[:len(socket.creds)-1] + return sockCred, true } } - return auth, false + return cred, false } -func (socket *mongoSocket) dropLogout(auth authInfo) (found bool) { - for i, a := range socket.logout { - if a == auth { +func (socket *mongoSocket) dropLogout(cred Credential) (found bool) { + for i, sockCred := range socket.logout { + if sockCred == cred { copy(socket.logout[i:], socket.logout[i+1:]) socket.logout = socket.logout[:len(socket.logout)-1] return true diff --git a/auth_test.go b/auth_test.go index 9db821639..e46ea4e66 100644 --- a/auth_test.go +++ b/auth_test.go @@ -27,14 +27,16 @@ package mgo_test import ( + "flag" "fmt" "labix.org/v2/mgo" . "launchpad.net/gocheck" + "net/url" "sync" "time" ) -func (s *S) TestAuthLogin(c *C) { +func (s *S) TestAuthLoginDatabase(c *C) { // Test both with a normal database and with an authenticated shard. for _, addr := range []string{"localhost:40002", "localhost:40203"} { session, err := mgo.Dial(addr) @@ -58,6 +60,34 @@ func (s *S) TestAuthLogin(c *C) { } } +func (s *S) TestAuthLoginSession(c *C) { + // Test both with a normal database and with an authenticated shard. + for _, addr := range []string{"localhost:40002", "localhost:40203"} { + session, err := mgo.Dial(addr) + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized .*") + + cred := mgo.Credential{ + Username: "root", + Password: "wrong", + } + err = session.Login(&cred) + c.Assert(err, ErrorMatches, "auth fails") + + cred.Password = "rapadura" + + err = session.Login(&cred) + c.Assert(err, IsNil) + + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) + } +} + func (s *S) TestAuthLoginLogout(c *C) { // Test both with a normal database and with an authenticated shard. for _, addr := range []string{"localhost:40002", "localhost:40203"} { @@ -707,15 +737,24 @@ func (s *S) TestAuthURLWithDatabase(c *C) { err = mydb.AddUser("myruser", "mypass", true) c.Assert(err, IsNil) - usession, err := mgo.Dial("mongodb://myruser:mypass@localhost:40002/mydb") - c.Assert(err, IsNil) - defer usession.Close() + // Test once with database, and once with source. + for i := 0; i < 2; i++ { + var url string + if i == 0 { + url = "mongodb://myruser:mypass@localhost:40002/mydb" + } else { + url = "mongodb://myruser:mypass@localhost:40002/admin?authSource=mydb" + } + usession, err := mgo.Dial(url) + c.Assert(err, IsNil) + defer usession.Close() - ucoll := usession.DB("mydb").C("mycoll") - err = ucoll.FindId(0).One(nil) - c.Assert(err, Equals, mgo.ErrNotFound) - err = ucoll.Insert(M{"n": 1}) - c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + ucoll := usession.DB("mydb").C("mycoll") + err = ucoll.FindId(0).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) + err = ucoll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + } } func (s *S) TestDefaultDatabase(c *C) { @@ -776,3 +815,49 @@ func (s *S) TestAuthDirectWithLogin(c *C) { c.Assert(err, Equals, mgo.ErrNotFound) } } + +var ( + kerberosFlag = flag.Bool("kerberos", false, "Test Kerberos authentication (depends on custom environment)") + kerberosHost = "mmscustmongo.10gen.me" + kerberosUser = "mmsagent/mmscustagent.10gen.me@10GEN.ME" +) + +func (s *S) TestAuthKerberosCred(c *C) { + if !*kerberosFlag { + c.Skip("no -kerberos") + } + cred := &mgo.Credential{ + Username: kerberosUser, + Mechanism: "GSSAPI", + } + c.Logf("Connecting to %s...", kerberosHost) + session, err := mgo.Dial(kerberosHost) + defer session.Close() + + c.Logf("Connected! Testing the need for authentication...") + c.Assert(err, IsNil) + names, err := session.DatabaseNames() + c.Assert(err, ErrorMatches, "unauthorized") + + c.Logf("Authenticating...") + err = session.Login(cred) + c.Assert(err, IsNil) + c.Logf("Authenticated!") + + names, err = session.DatabaseNames() + c.Assert(err, IsNil) + c.Assert(len(names) > 0, Equals, true) +} + +func (s *S) TestAuthKerberosURL(c *C) { + if !*kerberosFlag { + c.Skip("no -kerberos") + } + c.Logf("Connecting to %s...", kerberosHost) + session, err := mgo.Dial(url.QueryEscape(kerberosUser) + "@" + kerberosHost + "?authMechanism=GSSAPI") + c.Assert(err, IsNil) + defer session.Close() + names, err := session.DatabaseNames() + c.Assert(err, IsNil) + c.Assert(len(names) > 0, Equals, true) +} diff --git a/sasl/sasl.c b/sasl/sasl.c new file mode 100644 index 000000000..87c17c69a --- /dev/null +++ b/sasl/sasl.c @@ -0,0 +1,75 @@ +#include +#include +#include +#include + +static int mgo_sasl_simple(void *context, int id, const char **result, unsigned int *len) +{ + if (!result) { + return SASL_BADPARAM; + } + switch (id) { + case SASL_CB_USER: + *result = (char *) context; + break; + case SASL_CB_AUTHNAME: + *result = (char *) context; + break; + case SASL_CB_LANGUAGE: + *result = NULL; + break; + default: + return SASL_BADPARAM; + } + if (len) { + *len = *result ? strlen(*result) : 0; + } + return SASL_OK; +} + +typedef int (*callback)(void); + +static int mgo_sasl_secret(sasl_conn_t *conn, void *context, int id, sasl_secret_t **result) +{ + if (!conn || !result || id != SASL_CB_PASS) { + return SASL_BADPARAM; + } + *result = (sasl_secret_t *)context; + return SASL_OK; +} + +sasl_callback_t *mgo_sasl_callbacks(const char *username, const char *password) +{ + sasl_callback_t *cb = malloc(4 * sizeof(sasl_callback_t)); + int n = 0; + + size_t len = strlen(password); + sasl_secret_t *secret = (sasl_secret_t*)malloc(sizeof(sasl_secret_t) + len); + if (!secret) { + free(cb); + return NULL; + } + strcpy((char *)secret->data, password); + secret->len = len; + + cb[n].id = SASL_CB_PASS; + cb[n].proc = (callback)&mgo_sasl_secret; + cb[n].context = secret; + n++; + + cb[n].id = SASL_CB_USER; + cb[n].proc = (callback)&mgo_sasl_simple; + cb[n].context = (char*)username; + n++; + + cb[n].id = SASL_CB_AUTHNAME; + cb[n].proc = (callback)&mgo_sasl_simple; + cb[n].context = (char*)username; + n++; + + cb[n].id = SASL_CB_LIST_END; + cb[n].proc = NULL; + cb[n].context = NULL; + + return cb; +} diff --git a/sasl/sasl.go b/sasl/sasl.go new file mode 100644 index 000000000..e4a170463 --- /dev/null +++ b/sasl/sasl.go @@ -0,0 +1,135 @@ +// Package sasl is an implementation detail of the mgo package. +// +// This package is not meant to be used by itself. +// +package sasl + +// #cgo LDFLAGS: -lsasl2 +// +// struct sasl_conn {}; +// +// #include +// #include +// +// sasl_callback_t *mgo_sasl_callbacks(const char *username, const char *password); +// +import "C" + +import ( + "fmt" + "strings" + "sync" + "unsafe" +) + +type saslStepper interface { + Step(serverData []byte) (clientData []byte, done bool, err error) + Close() +} + +type saslSession struct { + conn *C.sasl_conn_t + step int + mech string + + cstrings []*C.char + callbacks *C.sasl_callback_t +} + +var initError error +var initOnce sync.Once + +func initSASL() { + rc := C.sasl_client_init(nil) + if rc != C.SASL_OK { + initError = saslError(rc, nil, "cannot initialize SASL library") + } +} + +func New(username, password, mechanism, service, host string) (saslStepper, error) { + initOnce.Do(initSASL) + if initError != nil { + return nil, initError + } + + ss := &saslSession{mech: mechanism} + if service == "" { + service = "mongodb" + } + if i := strings.Index(host, ":"); i >= 0 { + host = host[:i] + } + ss.callbacks = C.mgo_sasl_callbacks(ss.cstr(username), ss.cstr(password)) + rc := C.sasl_client_new(ss.cstr(service), ss.cstr(host), nil, nil, ss.callbacks, 0, &ss.conn) + if rc != C.SASL_OK { + ss.Close() + return nil, saslError(rc, nil, "cannot create new SASL client") + } + return ss, nil +} + +func (ss *saslSession) cstr(s string) *C.char { + cstr := C.CString(s) + ss.cstrings = append(ss.cstrings, cstr) + return cstr +} + +func (ss *saslSession) Close() { + for _, cstr := range ss.cstrings { + C.free(unsafe.Pointer(cstr)) + } + ss.cstrings = nil + + if ss.callbacks != nil { + C.free(unsafe.Pointer(ss.callbacks)) + } + + // The documentation of SASL dispose makes it clear that this should only + // be done when the connection is done, not when the authentication phase + // is done, because an encryption layer may have been negotiated. + // Even then, we'll do this for now, because it's simpler and prevents + // keeping track of this state for every socket. If it breaks, we'll fix it. + C.sasl_dispose(&ss.conn) +} + +func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, err error) { + ss.step++ + if ss.step > 10 { + return nil, false, fmt.Errorf("too many SASL steps without authentication") + } + var cclientData *C.char + var cclientDataLen C.uint + var rc C.int + if ss.step == 1 { + var mechanism *C.char // ignored - must match cred + rc = C.sasl_client_start(ss.conn, ss.cstr(ss.mech), nil, &cclientData, &cclientDataLen, &mechanism) + } else { + var cserverData *C.char + var cserverDataLen C.uint + if len(serverData) > 0 { + cserverData = (*C.char)(unsafe.Pointer(&serverData[0])) + cserverDataLen = C.uint(len(serverData)) + } + rc = C.sasl_client_step(ss.conn, cserverData, cserverDataLen, nil, &cclientData, &cclientDataLen) + } + if cclientData != nil && cclientDataLen > 0 { + clientData = C.GoBytes(unsafe.Pointer(cclientData), C.int(cclientDataLen)) + } + if rc == C.SASL_OK { + return clientData, true, nil + } + if rc == C.SASL_CONTINUE { + return clientData, false, nil + } + return nil, false, saslError(rc, ss.conn, "cannot establish SASL session") +} + +func saslError(rc C.int, conn *C.sasl_conn_t, msg string) error { + var detail string + if conn == nil { + detail = C.GoString(C.sasl_errstring(rc, nil, nil)) + } else { + detail = C.GoString(C.sasl_errdetail(conn)) + } + return fmt.Errorf(msg + ": " + detail) +} diff --git a/saslimpl.go b/saslimpl.go new file mode 100644 index 000000000..3b255def6 --- /dev/null +++ b/saslimpl.go @@ -0,0 +1,11 @@ +//+build sasl + +package mgo + +import ( + "labix.org/v2/mgo/sasl" +) + +func saslNew(cred Credential, host string) (saslStepper, error) { + return sasl.New(cred.Username, cred.Password, cred.Mechanism, cred.Service, host) +} diff --git a/saslstub.go b/saslstub.go new file mode 100644 index 000000000..6e9e30986 --- /dev/null +++ b/saslstub.go @@ -0,0 +1,11 @@ +//+build !sasl + +package mgo + +import ( + "fmt" +) + +func saslNew(cred Credential, host string) (saslStepper, error) { + return nil, fmt.Errorf("SASL support not enabled during build (-tags sasl)") +} diff --git a/session.go b/session.go index bb37a950b..48643b8b8 100644 --- a/session.go +++ b/session.go @@ -34,6 +34,7 @@ import ( "labix.org/v2/mgo/bson" "math" "net" + "net/url" "reflect" "sort" "strconv" @@ -65,8 +66,9 @@ type Session struct { syncTimeout time.Duration sockTimeout time.Duration defaultdb string - dialAuth *authInfo - auth []authInfo + sourcedb string + dialCred *Credential + creds []Credential } type Database struct { @@ -173,7 +175,7 @@ const defaultPrefetch = 0.25 // http://www.mongodb.org/display/DOCS/Connections // func Dial(url string) (*Session, error) { - session, err := DialWithTimeout(url, 10 * time.Second) + session, err := DialWithTimeout(url, 10*time.Second) if err == nil { session.SetSyncTimeout(1 * time.Minute) session.SetSocketTimeout(1 * time.Minute) @@ -193,8 +195,17 @@ func DialWithTimeout(url string, timeout time.Duration) (*Session, error) { return nil, err } direct := false + mechanism := "" + service := "" + source := "" for k, v := range uinfo.options { switch k { + case "authSource": + source = v + case "authMechanism": + mechanism = v + case "gssapiServiceName": + service = v case "connect": if v == "direct" { direct = true @@ -205,16 +216,19 @@ func DialWithTimeout(url string, timeout time.Duration) (*Session, error) { } fallthrough default: - return nil, errors.New("Unsupported connection URL option: " + k + "=" + v) + return nil, errors.New("unsupported connection URL option: " + k + "=" + v) } } info := DialInfo{ - Addrs: uinfo.addrs, - Direct: direct, - Timeout: timeout, - Username: uinfo.user, - Password: uinfo.pass, - Database: uinfo.db, + Addrs: uinfo.addrs, + Direct: direct, + Timeout: timeout, + Database: uinfo.db, + Username: uinfo.user, + Password: uinfo.pass, + Mechanism: mechanism, + Service: service, + Source: source, } return DialWithInfo(&info) } @@ -243,14 +257,26 @@ type DialInfo struct { // distinguish it from a slow server, so the timeout stays relevant. FailFast bool - // Database is the database name used during the initial authentication. - // If set, the value is also returned as the default result from the - // Session.DB method, in place of "test". + // Database is the default database name used when the Session.DB method + // is called with an empty name, and is also used during the intial + // authenticatoin if Source is unset. Database string - // Username and Password inform the credentials for the initial - // authentication done against Database, if that is set, - // or the "admin" database otherwise. See the Session.Login method too. + // Source is the database used to establish credentials and privileges + // with a MongoDB server. Defaults to the value of Database, if that is + // set, or "admin" otherwise. + Source string + + // Service defines the service name to use when authenticating with the GSSAPI + // mechanism. Defaults to "mongodb". + Service string + + // Mechanism defines the protocol for credential negotiation. + // Defaults to "MONGODB-CR". + Mechanism string + + // Username and Password inform the credentials for the initial authentication + // done against the database defined by the Source field. See Session.Login. Username string Password string @@ -296,13 +322,26 @@ func DialWithInfo(info *DialInfo) (*Session, error) { if session.defaultdb == "" { session.defaultdb = "test" } + session.sourcedb = info.Source + if session.sourcedb == "" { + session.sourcedb = info.Database + if session.sourcedb == "" { + session.sourcedb = "admin" + } + } if info.Username != "" { - db := info.Database - if db == "" { - db = "admin" + source := session.sourcedb + if info.Source == "" && info.Mechanism == "GSSAPI" { + source = "$external" } - session.dialAuth = &authInfo{db, info.Username, info.Password} - session.auth = []authInfo{*session.dialAuth} + session.dialCred = &Credential{ + Username: info.Username, + Password: info.Password, + Mechanism: info.Mechanism, + Service: info.Service, + Source: source, + } + session.creds = []Credential{*session.dialCred} } cluster.Release() @@ -330,35 +369,44 @@ type urlInfo struct { options map[string]string } -func parseURL(url string) (*urlInfo, error) { - if strings.HasPrefix(url, "mongodb://") { - url = url[10:] +func parseURL(s string) (*urlInfo, error) { + if strings.HasPrefix(s, "mongodb://") { + s = s[10:] } info := &urlInfo{options: make(map[string]string)} - if c := strings.Index(url, "?"); c != -1 { - for _, pair := range strings.FieldsFunc(url[c+1:], isOptSep) { + if c := strings.Index(s, "?"); c != -1 { + for _, pair := range strings.FieldsFunc(s[c+1:], isOptSep) { l := strings.SplitN(pair, "=", 2) if len(l) != 2 || l[0] == "" || l[1] == "" { return nil, errors.New("Connection option must be key=value: " + pair) } info.options[l[0]] = l[1] } - url = url[:c] + s = s[:c] } - if c := strings.Index(url, "@"); c != -1 { - pair := strings.SplitN(url[:c], ":", 2) - if len(pair) != 2 || pair[0] == "" { + if c := strings.Index(s, "@"); c != -1 { + pair := strings.SplitN(s[:c], ":", 2) + if len(pair) > 2 || pair[0] == "" { return nil, errors.New("Credentials must be provided as user:pass@host") } - info.user = pair[0] - info.pass = pair[1] - url = url[c+1:] + var err error + info.user, err = url.QueryUnescape(pair[0]) + if err != nil { + return nil, fmt.Errorf("cannot unescape username in URL: %q", pair[0]) + } + if len(pair) > 1 { + info.pass, err = url.QueryUnescape(pair[1]) + if err != nil { + return nil, fmt.Errorf("cannot unescape password in URL") + } + } + s = s[c+1:] } - if c := strings.Index(url, "/"); c != -1 { - info.db = url[c+1:] - url = url[:c] + if c := strings.Index(s, "/"); c != -1 { + info.db = s[c+1:] + s = s[:c] } - info.addrs = strings.Split(url, ",") + info.addrs = strings.Split(s, ",") return info, nil } @@ -372,7 +420,7 @@ func newSession(consistency mode, cluster *mongoCluster, timeout time.Duration) return session } -func copySession(session *Session, keepAuth bool) (s *Session) { +func copySession(session *Session, keepCreds bool) (s *Session) { cluster := session.cluster() cluster.Acquire() if session.masterSocket != nil { @@ -381,16 +429,16 @@ func copySession(session *Session, keepAuth bool) (s *Session) { if session.slaveSocket != nil { session.slaveSocket.Acquire() } - var auth []authInfo - if keepAuth { - auth = make([]authInfo, len(session.auth)) - copy(auth, session.auth) - } else if session.dialAuth != nil { - auth = []authInfo{*session.dialAuth} + var creds []Credential + if keepCreds { + creds = make([]Credential, len(session.creds)) + copy(creds, session.creds) + } else if session.dialCred != nil { + creds = []Credential{*session.dialCred} } scopy := *session scopy.m = sync.RWMutex{} - scopy.auth = auth + scopy.creds = creds s = &scopy debugf("New session %p on cluster %p (copy from %p)", s, cluster, session) return s @@ -488,45 +536,68 @@ func (db *Database) Run(cmd interface{}, result interface{}) error { return db.C("$cmd").Find(cmd).One(result) } -// Login authenticates against MongoDB with the provided credentials. The +// Credential holds details to authenticate with a MongoDB server. +type Credential struct { + // Username and Password hold the basic details for authentication. + // Password is optional with some authentication mechanisms. + Username string + Password string + + // Source is the database used to establish credentials and privileges + // with a MongoDB server. Defaults to the default database provided + // during dial, or "admin" if that was unset. + Source string + + // Service defines the service name to use when authenticating with the GSSAPI + // mechanism. Defaults to "mongodb". + Service string + + // Mechanism defines the protocol for credential negotiation. + // Defaults to "MONGODB-CR". + Mechanism string +} + +// Login authenticates against MongoDB with the provided credential. The // authentication is valid for the whole session and will stay valid until // Logout is explicitly called for the same database, or the session is // closed. -// -// Concurrent Login calls will work correctly. -func (db *Database) Login(user, pass string) (err error) { - session := db.Session - dbname := db.Name +func (db *Database) Login(user, pass string) error { + return db.Session.Login(&Credential{Username: user, Password: pass, Source: db.Name}) +} - socket, err := session.acquireSocket(true) +// Login authenticates against MongoDB with the provided credential. The +// authentication is valid for the whole session and will stay valid until +// Logout is explicitly called for the same database, or the session is +// closed. +func (s *Session) Login(cred *Credential) error { + socket, err := s.acquireSocket(true) if err != nil { return err } defer socket.Release() - auth := authInfo{dbname, user, pass} - err = socket.Login(auth) + credCopy := *cred + if cred.Source == "" { + if cred.Mechanism == "GSSAPI" { + credCopy.Source = "$external" + } else { + credCopy.Source = s.sourcedb + } + } + err = socket.Login(credCopy) if err != nil { return err } - session.m.Lock() - defer session.m.Unlock() - - for _, a := range session.auth { - if a.db == dbname { - a.user = user - a.pass = pass - return nil - } - } - session.auth = append(session.auth, auth) + s.m.Lock() + s.creds = append(s.creds, credCopy) + s.m.Unlock() return nil } func (s *Session) socketLogin(socket *mongoSocket) error { - for _, auth := range s.auth { - if err := socket.Login(auth); err != nil { + for _, cred := range s.creds { + if err := socket.Login(cred); err != nil { return err } } @@ -539,10 +610,10 @@ func (db *Database) Logout() { dbname := db.Name session.m.Lock() found := false - for i, a := range session.auth { - if a.db == dbname { - copy(session.auth[i:], session.auth[i+1:]) - session.auth = session.auth[:len(session.auth)-1] + for i, cred := range session.creds { + if cred.Source == dbname { + copy(session.creds[i:], session.creds[i+1:]) + session.creds = session.creds[:len(session.creds)-1] found = true break } @@ -561,15 +632,15 @@ func (db *Database) Logout() { // LogoutAll removes all established authentication credentials for the session. func (s *Session) LogoutAll() { s.m.Lock() - for _, a := range s.auth { + for _, cred := range s.creds { if s.masterSocket != nil { - s.masterSocket.Logout(a.db) + s.masterSocket.Logout(cred.Source) } if s.slaveSocket != nil { - s.slaveSocket.Logout(a.db) + s.slaveSocket.Logout(cred.Source) } } - s.auth = s.auth[0:0] + s.creds = s.creds[0:0] s.m.Unlock() } diff --git a/socket.go b/socket.go index 528f6354e..fc5712de1 100644 --- a/socket.go +++ b/socket.go @@ -45,8 +45,8 @@ type mongoSocket struct { nextRequestId uint32 replyFuncs map[uint32]replyFunc references int - auth []authInfo - logout []authInfo + creds []Credential + logout []Credential cachedNonce string gotNonce sync.Cond dead error @@ -547,7 +547,9 @@ func (socket *mongoSocket) readLoop() { for i := 0; i != int(reply.replyDocs); i++ { err := fill(conn, s) if err != nil { - replyFunc(err, nil, -1, nil) + if replyFunc != nil { + replyFunc(err, nil, -1, nil) + } socket.kill(err, true) return } @@ -562,7 +564,9 @@ func (socket *mongoSocket) readLoop() { err = fill(conn, b[4:]) if err != nil { - replyFunc(err, nil, -1, nil) + if replyFunc != nil { + replyFunc(err, nil, -1, nil) + } socket.kill(err, true) return }