Skip to content

Commit

Permalink
Make TestIntegrations/ReconcileLabels a unit test (gravitational#31124
Browse files Browse the repository at this point in the history
)

* Increase timeout for waiting for label update

* Advance clock more often

* Make TestReconcileLabels a unit test

* Fix imports

* Fix test

* Increase require.Eventually wait time

* Mock control stream
  • Loading branch information
atburke authored Sep 7, 2023
1 parent 5b01122 commit 4cb4ac8
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 80 deletions.
80 changes: 0 additions & 80 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ import (
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/gravitational/trace/trail"
"github.com/jonboulle/clockwork"
"github.com/pkg/sftp"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -193,7 +192,6 @@ func TestIntegrations(t *testing.T) {
t.Run("AuthLocalNodeControlStream", suite.bind(testAuthLocalNodeControlStream))
t.Run("AgentlessConnection", suite.bind(testAgentlessConnection))
t.Run("LeafAgentlessConnection", suite.bind(testTrustedClusterAgentless))
t.Run("ReconcileLabels", suite.bind(testReconcileLabels))
}

// testDifferentPinnedIP tests connection is rejected when source IP doesn't match the pinned one
Expand Down Expand Up @@ -7869,84 +7867,6 @@ func testAgentlessConnection(t *testing.T, suite *integrationTestSuite) {
testAgentlessConn(t, tc, node)
}

// testReconcileLabels verifies that an SSH server's labels can be updated by
// upserting a corresponding ServerInfo to the auth server.
func testReconcileLabels(t *testing.T, suite *integrationTestSuite) {
// Create Teleport cluster.
cfg := suite.defaultServiceConfig()
cfg.CachePolicy.Enabled = false
cfg.Proxy.DisableWebService = true
cfg.Proxy.DisableWebInterface = true
cfg.SSH.Labels = map[string]string{"foo": "bar"}
clock := clockwork.NewFakeClock()
cfg.Clock = clock
teleInst := suite.NewTeleportWithConfig(t, nil, nil, cfg)

t.Cleanup(func() { require.NoError(t, teleInst.StopAll()) })

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

require.NoError(t, helpers.WaitForNodeCount(ctx, teleInst, helpers.Site, 1))

authServer := teleInst.Process.GetAuthServer()
servers, err := authServer.GetNodes(ctx, defaults.Namespace)
require.NoError(t, err)
require.Len(t, servers, 1)

server := servers[0]
serverName := server.GetName()
require.Equal(t, map[string]string{"foo": "bar"}, server.GetStaticLabels())
server.SetCloudMetadata(&types.CloudMetadata{
AWS: &types.AWSInfo{
AccountID: "my-account",
InstanceID: "my-instance",
},
})
_, err = authServer.UpsertNode(ctx, server)
require.NoError(t, err)

// Update the server's labels.
labels := map[string]string{"a": "1", "b": "2"}
serverInfo, err := types.NewServerInfo(types.Metadata{
Name: "aws-my-account-my-instance",
Labels: labels,
}, types.ServerInfoSpecV1{})
require.NoError(t, err)
serverInfo.SetSubKind(types.SubKindCloudInfo)
require.NoError(t, authServer.UpsertServerInfo(ctx, serverInfo))

watcher, err := authServer.NewWatcher(ctx, types.Watch{
Kinds: []types.WatchKind{
{
Kind: types.KindNode,
},
},
})
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, watcher.Close()) })

timeout := time.After(5 * time.Second)
// Wait for server to receive updated labels.
for {
clock.Advance(15 * time.Minute)
select {
case <-timeout:
require.Fail(t, "Timed out waiting for server update")
case event := <-watcher.Events():
if event.Type != types.OpPut || event.Resource.GetName() != serverName {
continue
}
if utils.StringMapsEqual(
map[string]string{"foo": "bar", "a": "1", "b": "2"},
event.Resource.GetMetadata().Labels,
) {
return
}
}
}
}

func createAgentlessNode(t *testing.T, authServer *auth.Server, clusterName, nodeHostname string) *types.ServerV2 {
ctx := context.Background()
openSSHCA, err := authServer.GetCertAuthority(ctx, types.CertAuthID{
Expand Down
111 changes: 111 additions & 0 deletions lib/auth/server_info_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
Copyright 2023 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package auth

import (
"context"
"testing"

"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
)

type mockUpstream struct {
client.UpstreamInventoryControlStream
updatedLabels map[string]string
}

func (m *mockUpstream) Send(_ context.Context, msg proto.DownstreamInventoryMessage) error {
if labelMsg, ok := msg.(proto.DownstreamInventoryUpdateLabels); ok {
m.updatedLabels = labelMsg.Labels
}
return nil
}

func (m *mockUpstream) Recv() <-chan proto.UpstreamInventoryMessage {
return make(chan proto.UpstreamInventoryMessage)
}

func (m *mockUpstream) Done() <-chan struct{} {
return make(chan struct{})
}

func (m *mockUpstream) Close() error {
return nil
}

// TestReconcileLabels verifies that an SSH server's labels can be updated by
// upserting a corresponding ServerInfo to the auth server.
func TestReconcileLabels(t *testing.T) {
t.Parallel()

const serverName = "test-server"
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

// Create auth server and fake inventory stream.
clock := clockwork.NewFakeClock()
pack, err := newTestPack(ctx, t.TempDir(), WithClock(clock))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, pack.a.Close())
require.NoError(t, pack.bk.Close())
})
upstream := &mockUpstream{}
t.Cleanup(func() {
require.NoError(t, upstream.Close())
})
require.NoError(t, pack.a.RegisterInventoryControlStream(upstream, proto.UpstreamInventoryHello{
Version: teleport.Version,
ServerID: serverName,
Services: []types.SystemRole{types.RoleNode},
}))

// Create server.
server, err := types.NewServer(serverName, types.KindNode, types.ServerSpecV2{
CloudMetadata: &types.CloudMetadata{
AWS: &types.AWSInfo{
AccountID: "my-account",
InstanceID: "my-instance",
},
},
})
require.NoError(t, err)
_, err = pack.a.UpsertNode(ctx, server)
require.NoError(t, err)

// Update the server's labels.
labels := map[string]string{"a": "1", "b": "2"}
serverInfo, err := types.NewServerInfo(types.Metadata{
Name: "aws-my-account-my-instance",
Labels: labels,
}, types.ServerInfoSpecV1{})
require.NoError(t, err)
serverInfo.SetSubKind(types.SubKindCloudInfo)
require.NoError(t, pack.a.UpsertServerInfo(ctx, serverInfo))

go pack.a.ReconcileServerInfos(ctx)
// Wait until the reconciler finishes processing the serverinfo.
clock.BlockUntil(1)
// Check that labels were received downstream.
require.Equal(t, labels, upstream.updatedLabels)
}

0 comments on commit 4cb4ac8

Please sign in to comment.