Skip to content

Commit

Permalink
Properly check for SAMLIdPServiceProvider access (gravitational#33190)
Browse files Browse the repository at this point in the history
* Properly check for SAMLIdPServiceProvider access

* Remove unneeded debug log
  • Loading branch information
avatus authored Oct 11, 2023
1 parent b905eaf commit cea7a60
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 24 deletions.
31 changes: 11 additions & 20 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -1505,23 +1505,10 @@ func (a *ServerWithRoles) GetNode(ctx context.Context, namespace, name string) (
func (a *ServerWithRoles) ListUnifiedResources(ctx context.Context, req *proto.ListUnifiedResourcesRequest) (*proto.ListUnifiedResourcesResponse, error) {
// Fetch full list of resources in the backend.
var (
elapsedFetch time.Duration
elapsedFilter time.Duration
unifiedResources types.ResourcesWithLabels
filteredResources types.ResourcesWithLabels
nextKey string
unifiedResources types.ResourcesWithLabels
nextKey string
)

defer func() {
log.WithFields(logrus.Fields{
"user": a.context.User.GetName(),
"elapsed_fetch": elapsedFetch,
"elapsed_filter": elapsedFilter,
}).Debugf(
"ListUnifiedResources(%v->%v) in %v.",
len(unifiedResources), len(filteredResources), elapsedFetch+elapsedFilter)
}()

filter := services.MatchResourceFilter{
Labels: req.Labels,
SearchKeywords: req.SearchKeywords,
Expand All @@ -1533,8 +1520,6 @@ func (a *ServerWithRoles) ListUnifiedResources(ctx context.Context, req *proto.L
if err != nil {
return nil, trace.Wrap(err)
}
startFetch := time.Now()
startFilter := time.Now()
if req.PinnedOnly {
prefs, err := a.authServer.GetUserPreferences(ctx, a.context.User.GetName())
if err != nil {
Expand All @@ -1555,7 +1540,15 @@ func (a *ServerWithRoles) ListUnifiedResources(ctx context.Context, req *proto.L
}
} else {
unifiedResources, nextKey, err = a.authServer.UnifiedResourceCache.IterateUnifiedResources(ctx, func(resource types.ResourceWithLabels) (bool, error) {
if err := resourceChecker.CanAccess(resource); err != nil {
var err error
switch r := resource.(type) {
case types.SAMLIdPServiceProvider:
err = a.action(apidefaults.Namespace, types.KindSAMLIdPServiceProvider, types.VerbList)
default:
err = resourceChecker.CanAccess(r)
}

if err != nil {
if trace.IsAccessDenied(err) {
return false, nil
}
Expand All @@ -1568,8 +1561,6 @@ func (a *ServerWithRoles) ListUnifiedResources(ctx context.Context, req *proto.L
return nil, trace.Wrap(err, "filtering unified resources")
}
}
elapsedFetch = time.Since(startFetch)
elapsedFilter = time.Since(startFilter)

paginatedResources, err := services.MakePaginatedResources(types.KindUnifiedResource, unifiedResources)
if err != nil {
Expand Down
21 changes: 17 additions & 4 deletions lib/auth/auth_with_roles_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4273,7 +4273,7 @@ func TestListUnifiedResources_WithSearch(t *testing.T) {
t.Parallel()
ctx := context.Background()
srv := newTestTLSServer(t)
names := []string{"tifa", "cloud", "aerith", "baret", "cid", "tifa2"}
names := []string{"vivi", "cloud", "aerith", "barret", "cid", "vivi2"}
for i := 0; i < 6; i++ {
name := names[i]
node, err := types.NewServerWithLabels(
Expand All @@ -4293,6 +4293,19 @@ func TestListUnifiedResources_WithSearch(t *testing.T) {
require.NoError(t, err)
require.Len(t, testNodes, 6)

sp := &types.SAMLIdPServiceProviderV1{
ResourceHeader: types.ResourceHeader{
Metadata: types.Metadata{
Name: "tifaSAML",
},
},
Spec: types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: newEntityDescriptor("tifaSAML"),
EntityID: "tifaSAML",
},
}
require.NoError(t, srv.Auth().CreateSAMLIdPServiceProvider(ctx, sp))

// create user and client
user, _, err := CreateUserAndRole(srv.Auth(), "user", nil, nil)
require.NoError(t, err)
Expand All @@ -4304,13 +4317,13 @@ func TestListUnifiedResources_WithSearch(t *testing.T) {
SortBy: types.SortBy{IsDesc: true, Field: types.ResourceMetadataName},
})
require.NoError(t, err)
require.Len(t, resp.Resources, 2)
require.Len(t, resp.Resources, 1)
require.Empty(t, resp.NextKey)

// Check that our returned resource has the correct name
for _, resource := range resp.Resources {
r := resource.GetNode()
require.True(t, strings.Contains(r.GetHostname(), "tifa"))
r := resource.GetAppServerOrSAMLIdPServiceProvider()
require.True(t, strings.Contains(r.GetName(), "tifa"))
}
}

Expand Down

0 comments on commit cea7a60

Please sign in to comment.