Skip to content

Commit

Permalink
✨Add LDAP Authentication
Browse files Browse the repository at this point in the history
Signed-off-by: Stany MARCEL <[email protected]>
  • Loading branch information
ynsta committed May 16, 2020
1 parent 9f7c8bf commit 367b0f8
Show file tree
Hide file tree
Showing 5 changed files with 460 additions and 29 deletions.
119 changes: 119 additions & 0 deletions internal/auth/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package auth

import (
"database/sql"
"fmt"

"github.com/go-shiori/shiori/internal/database"
"github.com/go-shiori/shiori/internal/env"
"golang.org/x/crypto/bcrypt"
)

// Status is the authentication status returned by Check
type Status int

const (
// Unauthorized username or wrong password
Unauthorized Status = iota
// Visitor means that the user is authorized as a visitor
Visitor
// Owner means that the user is authorized as an owner
Owner
)

var (
ldapAuth *LDAPAuth
)

func init() {
if env.GetEnvBool("SHIORI_AUTH_LDAP", false) {
tmp, err := NewLDAPAuth(LDAPSettings{
Host: env.GetEnvString("SHIORI_AUTH_LDAP_HOST", "ldap"),
Port: int(env.GetEnvInt64("SHIORI_AUTH_LDAP_PORT", 389)),
StartTLS: env.GetEnvBool("SHIORI_AUTH_LDAP_TLS_ENABLED", true),
SkipCertificateVerif: env.GetEnvBool("SHIORI_AUTH_LDAP_TLS_SKIP_VERIF", true),
ThrustedCertificates: env.GetEnvStringList("SHIORI_AUTH_LDAP_TLS_THRUSTED_CERTIFICATES", []string{}),
UserGroupFilter: env.GetEnvString(
"SHIORI_AUTH_LDAP_SEARCH_FILTER",
"(&(|(mail={{.Login}})(sAMAccountName={{.Login}}))(memberOf={{.Group}}))",
),
BaseDN: env.GetEnvString("SHIORI_AUTH_LDAP_SEARCH_BASE", ""),
BindDN: env.GetEnvString("SHIORI_AUTH_LDAP_BIND_USERDN", ""),
BindDNPassword: env.GetEnvString("SHIORI_AUTH_LDAP_BIND_PASSWORD", ""),
})
if err == nil {
ldapAuth = &tmp
}
}
}

// Check username, password with configured auth methods
func Check(username string, password string, db database.DB) (Status, string) {

if ldapAuth != nil {
loginField := env.GetEnvString("SHIORI_AUTH_LDAP_LOGIN_FIELD", "sAMAccountName")
ownerGroup := env.GetEnvString("SHIORI_AUTH_LDAP_OWNER_GROUP", "")
visitorGroup := env.GetEnvString("SHIORI_AUTH_LDAP_VISITOR_GROUP", "")
oDN, oLogin, oErr := ldapAuth.Search(
username,
ownerGroup,
loginField,
)
vDN, vLogin, vErr := ldapAuth.Search(
username,
visitorGroup,
loginField,
)

if oErr == nil {
fmt.Printf("LDAP: owner found: %s\n", oDN)
if ldapAuth.VerifyDN(oDN, password) == nil {
return Owner, oLogin
}
} else if vErr == nil {
fmt.Printf("LDAP: visitor found: %s\n", vDN)
if ldapAuth.VerifyDN(vDN, password) == nil {
return Visitor, vLogin
}
}
fmt.Printf("LDAP: not found (%v, %v)\n", oErr, vErr)
}

defaultUser := env.GetEnvString("SHIORI_DEFAULT_USER", "shiori")
defaultPassword := env.GetEnvString("SHIORI_DEFAULT_PASSWORD", "gopher")

// Check if user's database is empty or there are no owner.
// If yes, and user uses default account, let him in.
searchOptions := database.GetAccountsOptions{
Owner: true,
}

accounts, err := db.GetAccounts(searchOptions)
if err != nil && err != sql.ErrNoRows {
panic(err)
}

if len(accounts) == 0 && username == defaultUser && password == defaultPassword {
return Owner, defaultUser
}

// Get account data from database
account, exist := db.GetAccount(username)
hash := ""
if exist {
hash = account.Password
}

// Compare password with database
err = bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
if hash == "" || err != nil {
return Unauthorized, username
}

// If login request is as owner, make sure this account is owner
if account.Owner {
return Owner, username
}
return Visitor, username

}
161 changes: 161 additions & 0 deletions internal/auth/ldapauth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package auth

import (
"bytes"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"log"
"text/template"

"github.com/go-ldap/ldap"
)

// LDAPAuth class
type LDAPAuth struct {
settings LDAPSettings
certs *x509.CertPool
}

// LDAPSettings used for auth object creation
type LDAPSettings struct {
Host string // ldap host use full hostname if tls is used
Port int // ldap port default is 389
StartTLS bool // Start TLS, Disable if not supported by server but credentials will transit without any encryption
SkipCertificateVerif bool // Skip certificate verification only use for debug purpose
ThrustedCertificates []string // List of thrusted CA and certificates
UserGroupFilter string // Filter used to search provided user{{.Login}} & group{{.Group}}
BaseDN string // Base DN for users
BindDN string // DN used to bind for search operations
BindDNPassword string // DN credential used to bind for search operations
}

// NewLDAPAuth returns a ldap auth object from given setting
func NewLDAPAuth(settings LDAPSettings) (LDAPAuth, error) {
la := LDAPAuth{
settings: settings,
}

la.certs = x509.NewCertPool()
for _, cert := range la.settings.ThrustedCertificates {
if data, err := ioutil.ReadFile(cert); err == nil {
if !la.certs.AppendCertsFromPEM(data) {
log.Println("ERROR: LDAP Unable to load certificate " + cert)
}
} else {
log.Println("ERROR: LDAP Unable to read certificate " + cert + " " + err.Error())
}

}
return la, nil
}

func (la *LDAPAuth) connect() (*ldap.Conn, error) {
// log.Println("LDAP:connect: Dial")
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", la.settings.Host, la.settings.Port))
if err != nil {
return nil, err
}

if la.settings.StartTLS {
// Reconnect with TLS
var tlsConfig tls.Config
if la.settings.SkipCertificateVerif {
log.Println("WARNING: LDAP LDAPAuth with TLS without certificate verification")
tlsConfig = tls.Config{InsecureSkipVerify: true}
} else {
tlsConfig = tls.Config{
ServerName: la.settings.Host,
InsecureSkipVerify: false,
RootCAs: la.certs,
}
}
// log.Println("LDAP:connect: StartTLS")
err = l.StartTLS(&tlsConfig)
if err != nil {
return nil, err
}
} else {
log.Println("WARNING: LDAP LDAPAuth without TLS")
}

// Bind with read only user
// log.Println("LDAP:connect: Bind as", la.settings.BindDN)
err = l.Bind(la.settings.BindDN, la.settings.BindDNPassword)
if err != nil {
return nil, err
}
// log.Println("LDAP:connect: Done")
return l, nil
}

func (la *LDAPAuth) search(l *ldap.Conn, username string, group string, loginField string) (string, string, error) {

// Generate filter from username and group
type Search struct {
Login string
Group string
}

data := Search{
Login: username,
Group: group,
}
t := template.Must(template.New("filter").Parse(la.settings.UserGroupFilter))
buf := bytes.NewBufferString("")
t.Execute(buf, data)
filter := buf.String()

attributes := []string{
"dn",
}
if loginField != "" {
attributes = append(attributes, loginField)
}

searchRequest := ldap.NewSearchRequest(
la.settings.BaseDN, // The base dn to search
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
filter, // The filter to apply
attributes,
nil,
)

sr, err := l.Search(searchRequest)
if err != nil {
return "", "", err
}

if len(sr.Entries) != 1 {
err := errors.New("LDAP: User `" + username + "' does not exist or too many entries returned")
return "", "", err
}

dn := sr.Entries[0].DN
if loginField != "" {
return dn, sr.Entries[0].GetAttributeValue(loginField), nil
}
return dn, username, nil
}

// Search connect and search for a username in the ldap, add the entry
func (la *LDAPAuth) Search(username string, group string, loginField string) (string, string, error) {
l, err := la.connect()
if err != nil {
return "", "", err
}
defer l.Close()
return la.search(l, username, group, loginField)
}

// VerifyDN connect, and verify the password of dn identified user
func (la *LDAPAuth) VerifyDN(dn string, password string) error {
l, err := la.connect()
if err != nil {
return err
}
defer l.Close()
return l.Bind(dn, password)
}
66 changes: 66 additions & 0 deletions internal/env/env.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package env

import (
"os"
"regexp"
"strconv"
"strings"
)

// GetEnvString retrieves the value of the environment variable named by the key.
// If the variable is present in the environment the value (which may be empty)
// is returned as string, otherwise the defaultValue is returned.
func GetEnvString(key, defaultValue string) string {
value, ok := os.LookupEnv(key)
if !ok {
return defaultValue
}
return value
}

// GetEnvStringList retrieves the value of the environment variable named by the key.
// If the variable is present in the environment the value (which may be empty)
// is split by ',' or ':' and returned as string List, otherwise the defaultValue
// is returned.
func GetEnvStringList(key string, defaultValue []string) []string {
value, ok := os.LookupEnv(key)
if !ok {
return defaultValue
}
return regexp.MustCompile(`\s*[,:]\s*`).Split(value, -1)
}

// GetEnvInt64 retrieves the value of the environment variable named by the key.
// If the variable is present in the environment the value (which may be empty)
// is returned as int64, otherwise the defaultValue is returned.
func GetEnvInt64(key string, defaultValue int64) int64 {
value, ok := os.LookupEnv(key)
if !ok {
return defaultValue
}
if intValue, err := strconv.ParseInt(value, 0, 64); err == nil {
return intValue
}
return defaultValue
}

// GetEnvInt64 retrieves the value of the environment variable named by the key.
// If the variable is present in the environment the value (which may be empty)
// is returned as bool, otherwise the defaultValue is returned.
// Accepted true value are in ["1", "t", "y", "true", "yes"] in any case.
// Accepted false value are in ["0", "f", "n", "false", "no"] in any case.
func GetEnvBool(key string, defaultValue bool) bool {
value, ok := os.LookupEnv(key)
if !ok {
return defaultValue
}
value = strings.ToLower(value)
switch value {
case "1", "t", "y", "true", "yes":
return true
case "0", "f", "n", "false", "no":
return false
default:
return defaultValue
}
}
Loading

0 comments on commit 367b0f8

Please sign in to comment.