Skip to content

Commit

Permalink
session: allow applying multiple policies
Browse files Browse the repository at this point in the history
And add the extra field to the proto too.
  • Loading branch information
mvdan authored and buger committed Oct 6, 2017
1 parent e2e6bed commit d18d7e4
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 77 deletions.
36 changes: 17 additions & 19 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,25 +84,23 @@ func GetSpecForOrg(apiID string) *APISpec {
}

func checkAndApplyTrialPeriod(keyName, apiId string, newSession *SessionState) {
// Check the policy to see if we are forcing an expiry on the key
if newSession.ApplyPolicyID == "" {
return
}
policiesMu.RLock()
policy, ok := policiesByID[newSession.ApplyPolicyID]
policiesMu.RUnlock()
if !ok {
return
}
// Are we foring an expiry?
if policy.KeyExpiresIn > 0 {
// We are, does the key exist?
_, found := GetKeyDetail(keyName, apiId)
if !found {
// this is a new key, lets expire it
newSession.Expires = time.Now().Unix() + policy.KeyExpiresIn
// Check the policies to see if we are forcing an expiry on the key
for _, polID := range newSession.PolicyIDs() {
policiesMu.RLock()
policy, ok := policiesByID[polID]
policiesMu.RUnlock()
if !ok {
continue
}
// Are we foring an expiry?
if policy.KeyExpiresIn > 0 {
// We are, does the key exist?
_, found := GetKeyDetail(keyName, apiId)
if !found {
// this is a new key, lets expire it
newSession.Expires = time.Now().Unix() + policy.KeyExpiresIn
}
}

}
}

Expand Down Expand Up @@ -644,7 +642,7 @@ func handleUpdateHashedKey(keyName, apiID, policyId string) (interface{}, int) {

// Set the policy
sess.LastUpdated = strconv.Itoa(int(time.Now().Unix()))
sess.ApplyPolicyID = policyId
sess.SetPolicies(policyId)

sessAsJS, err := json.Marshal(sess)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions coprocess/proto/coprocess_session_state.proto
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,5 @@ message SessionState {

int64 id_extractor_deadline = 27;
int64 session_lifetime = 28;
repeated string apply_policies = 29;
}
2 changes: 2 additions & 0 deletions coprocess_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ func TykSessionState(session *coprocess.SessionState) *SessionState {
session.HmacSecret,
session.IsInactive,
session.ApplyPolicyId,
session.ApplyPolicies,
session.DataExpires,
monitor,
session.EnableDetailedRecording,
Expand Down Expand Up @@ -131,6 +132,7 @@ func ProtoSessionState(session *SessionState) *coprocess.SessionState {
HmacSecret: session.HmacSecret,
IsInactive: session.IsInactive,
ApplyPolicyId: session.ApplyPolicyID,
ApplyPolicies: session.ApplyPolicies,
DataExpires: session.DataExpires,
Monitor: monitor,
EnableDetailedRecording: session.EnableDetailedRecording,
Expand Down
130 changes: 82 additions & 48 deletions middleware.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"fmt"
"net/http"
"strconv"
"time"
Expand Down Expand Up @@ -153,71 +154,104 @@ func (t BaseMiddleware) OrgSessionExpiry(orgid string) int64 {
return cachedVal.(int64)
}

// ApplyPolicyIfExists will check if a policy is loaded, if it is, it will overwrite the session state to use the policy values
func (t BaseMiddleware) ApplyPolicyIfExists(key string, session *SessionState) {
if session.ApplyPolicyID == "" {
return
}
policiesMu.RLock()
policy, ok := policiesByID[session.ApplyPolicyID]
policiesMu.RUnlock()
if !ok {
return
}
// Check ownership, policy org owner must be the same as API,
// otherwise youcould overwrite a session key with a policy from a different org!
if policy.OrgID != t.Spec.OrgID {
log.Error("Attempting to apply policy from different organisation to key, skipping")
return
}
// ApplyPolicies will check if any policies are loaded. If any are, it
// will overwrite the session state to use the policy values.
func (t BaseMiddleware) ApplyPolicies(key string, session *SessionState) {
tags := make(map[string]bool)
didQuota, didRateLimit, didACL := false, false, false
policies := session.PolicyIDs()
for i, polID := range policies {
policiesMu.RLock()
policy, ok := policiesByID[polID]
policiesMu.RUnlock()
if !ok {
return
}
// Check ownership, policy org owner must be the same as API,
// otherwise youcould overwrite a session key with a policy from a different org!
if policy.OrgID != t.Spec.OrgID {
log.Error("Attempting to apply policy from different organisation to key, skipping")
return
}

if policy.Partitions.Quota || policy.Partitions.RateLimit || policy.Partitions.Acl {
// This is a partitioned policy, only apply what is active
if policy.Partitions.Quota {
if didQuota {
log.Error("Cannot apply multiple quota policies")
return
}
didQuota = true
// Quotas
session.QuotaMax = policy.QuotaMax
session.QuotaRenewalRate = policy.QuotaRenewalRate
}

if policy.Partitions.RateLimit {
if didRateLimit {
log.Error("Cannot apply multiple rate limit policies")
return
}
didRateLimit = true
// Rate limting
session.Allowance = policy.Rate // This is a legacy thing, merely to make sure output is consistent. Needs to be purged
session.Rate = policy.Rate
session.Per = policy.Per
if policy.LastUpdated != "" {
session.LastUpdated = policy.LastUpdated
}
}

if policy.Partitions.Acl {
// ACL
if !didACL { // first, overwrite rights
session.AccessRights = policy.AccessRights
didACL = true
} else { // second or later, merge
for k, v := range policy.AccessRights {
session.AccessRights[k] = v
}
}
session.HMACEnabled = policy.HMACEnabled
}

if policy.Partitions.Quota || policy.Partitions.RateLimit || policy.Partitions.Acl {
// This is a partitioned policy, only apply what is active
if policy.Partitions.Quota {
} else {
if len(policies) > 1 {
log.Error("Cannot apply multiple policies if any are non-partitioned")
return
}
// This is not a partitioned policy, apply everything
// Quotas
session.QuotaMax = policy.QuotaMax
session.QuotaRenewalRate = policy.QuotaRenewalRate
}

if policy.Partitions.RateLimit {
// Rate limting
session.Allowance = policy.Rate // This is a legacy thing, merely to make sure output is consistent. Needs to be purged
session.Rate = policy.Rate
session.Per = policy.Per
if policy.LastUpdated != "" {
session.LastUpdated = policy.LastUpdated
}
}

if policy.Partitions.Acl {
// ACL
session.AccessRights = policy.AccessRights
session.HMACEnabled = policy.HMACEnabled
}

} else {
// This is not a partitioned policy, apply everything
// Quotas
session.QuotaMax = policy.QuotaMax
session.QuotaRenewalRate = policy.QuotaRenewalRate

// Rate limting
session.Allowance = policy.Rate // This is a legacy thing, merely to make sure output is consistent. Needs to be purged
session.Rate = policy.Rate
session.Per = policy.Per
if policy.LastUpdated != "" {
session.LastUpdated = policy.LastUpdated
// Required for all
if i == 0 { // if any is true, key is inactive
session.IsInactive = policy.IsInactive
} else if policy.IsInactive {
session.IsInactive = true
}
for _, tag := range policy.Tags {
tags[tag] = true
}

// ACL
session.AccessRights = policy.AccessRights
session.HMACEnabled = policy.HMACEnabled
}

// Required for all
session.IsInactive = policy.IsInactive
session.Tags = policy.Tags

session.Tags = make([]string, 0, len(tags))
for tag := range tags {
session.Tags = append(session.Tags, tag)
}
// Update the session in the session manager in case it gets called again
t.Spec.SessionManager.UpdateSession(key, session, getLifetime(t.Spec, session))
}
Expand All @@ -233,7 +267,7 @@ func (t BaseMiddleware) CheckSessionAndIdentityForValidKey(key string) (SessionS
if found {
log.Debug("--> Key found in local cache")
session := cachedVal.(SessionState)
t.ApplyPolicyIfExists(key, &session)
t.ApplyPolicies(key, &session)
return session, true
}
}
Expand All @@ -247,7 +281,7 @@ func (t BaseMiddleware) CheckSessionAndIdentityForValidKey(key string) (SessionS
go SessionCache.Set(key, session, cache.DefaultExpiration)

// Check for a policy, if there is a policy, pull it and overwrite the session values
t.ApplyPolicyIfExists(key, &session)
t.ApplyPolicies(key, &session)
log.Debug("--> Got key")
return session, true
}
Expand All @@ -263,7 +297,7 @@ func (t BaseMiddleware) CheckSessionAndIdentityForValidKey(key string) (SessionS
go SessionCache.Set(key, session, cache.DefaultExpiration)

// Check for a policy, if there is a policy, pull it and overwrite the session values
t.ApplyPolicyIfExists(key, &session)
t.ApplyPolicies(key, &session)

log.Debug("Lifetime is: ", getLifetime(t.Spec, &session))
// Need to set this in order for the write to work!
Expand Down
7 changes: 4 additions & 3 deletions mw_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,13 @@ func (k *JWTMiddleware) getBasePolicyID(token *jwt.Token) (string, bool) {
return "", false
}

if clientSession.ApplyPolicyID == "" {
pols := clientSession.PolicyIDs()
if len(pols) < 1 {
return "", false
}

// Use the policy from the client ID
return clientSession.ApplyPolicyID, true
return pols[0], true
}

return "", false
Expand Down Expand Up @@ -440,7 +441,7 @@ func generateSessionFromPolicy(policyID, orgID string, enforceOrg bool) (Session
orgID = policy.OrgID
}

session.ApplyPolicyID = policyID
session.SetPolicies(policyID)
session.OrgID = orgID
session.Allowance = policy.Rate // This is a legacy thing, merely to make sure output is consistent. Needs to be purged
session.Rate = policy.Rate
Expand Down
2 changes: 1 addition & 1 deletion mw_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func createJWTSessionWithRSA() *SessionState {

func createJWTSessionWithRSAWithPolicy() *SessionState {
session := createJWTSessionWithRSA()
session.ApplyPolicyID = "987654321"
session.SetPolicies("987654321")
return session
}

Expand Down
3 changes: 2 additions & 1 deletion oauth_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"encoding/json"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"

Expand Down Expand Up @@ -282,7 +283,7 @@ func TestAPIClientAuthorizeTokenWithPolicy(t *testing.T) {
t.Error("Key was not created (Can't find it)!")
}

if session.ApplyPolicyID != "TEST-4321" {
if !reflect.DeepEqual(session.PolicyIDs(), []string{"TEST-4321"}) {
t.Error("Policy not added to token!")
}
}
Expand Down
29 changes: 24 additions & 5 deletions session_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,12 @@ type SessionState struct {
JWTData struct {
Secret string `json:"secret" msg:"secret"`
} `json:"jwt_data" msg:"jwt_data"`
HMACEnabled bool `json:"hmac_enabled" msg:"hmac_enabled"`
HmacSecret string `json:"hmac_string" msg:"hmac_string"`
IsInactive bool `json:"is_inactive" msg:"is_inactive"`
ApplyPolicyID string `json:"apply_policy_id" msg:"apply_policy_id"`
DataExpires int64 `json:"data_expires" msg:"data_expires"`
HMACEnabled bool `json:"hmac_enabled" msg:"hmac_enabled"`
HmacSecret string `json:"hmac_string" msg:"hmac_string"`
IsInactive bool `json:"is_inactive" msg:"is_inactive"`
ApplyPolicyID string `json:"apply_policy_id" msg:"apply_policy_id"`
ApplyPolicies []string `json:"apply_policies" msg:"apply_policies"`
DataExpires int64 `json:"data_expires" msg:"data_expires"`
Monitor struct {
TriggerLimits []float64 `json:"trigger_limits" msg:"trigger_limits"`
} `json:"monitor" msg:"monitor"`
Expand Down Expand Up @@ -116,3 +117,21 @@ func getLifetime(spec *APISpec, session *SessionState) int64 {
}
return 0
}

// PolicyIDs returns the IDs of all the policies applied to this
// session. For backwards compatibility reasons, this falls back to
// ApplyPolicyID if ApplyPolicies is empty.
func (s *SessionState) PolicyIDs() []string {
if len(s.ApplyPolicies) > 0 {
return s.ApplyPolicies
}
if s.ApplyPolicyID != "" {
return []string{s.ApplyPolicyID}
}
return nil
}

func (s *SessionState) SetPolicies(ids ...string) {
s.ApplyPolicyID = ""
s.ApplyPolicies = ids
}

0 comments on commit d18d7e4

Please sign in to comment.