Skip to content

Commit

Permalink
Merge branch 'master' into adding-go-modules
Browse files Browse the repository at this point in the history
  • Loading branch information
marcosy committed Jan 2, 2019
2 parents 34a6dbe + cbbf489 commit d2fa6c3
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 43 deletions.
40 changes: 27 additions & 13 deletions pkg/server/endpoints/node/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (h *Handler) Attest(stream node.Node_AttestServer) (err error) {
return errors.New("Error trying to sign CSR")
}

if err := h.updateNodeSelectors(ctx, baseSpiffeIDFromCSR, attestResponse); err != nil {
if err := h.updateNodeSelectors(ctx, baseSpiffeIDFromCSR, attestResponse, request.AttestationData.Type); err != nil {
h.c.Log.Error(err)
return errors.New("Error trying to get selectors for baseSpiffeID")
}
Expand Down Expand Up @@ -498,25 +498,39 @@ func (h *Handler) createAttestationEntry(ctx context.Context,
}

func (h *Handler) updateNodeSelectors(ctx context.Context,
baseSpiffeID string, attestResponse *nodeattestor.AttestResponse) error {

nodeResolver := h.c.Catalog.NodeResolvers()[0]
//Call node resolver plugin to get a map of spiffeID=>Selector
response, err := nodeResolver.Resolve(ctx, &noderesolver.ResolveRequest{
BaseSpiffeIdList: []string{baseSpiffeID},
})
if err != nil {
return err
baseSpiffeID string, attestResponse *nodeattestor.AttestResponse, attestationType string) error {

// Select node resolver based on request attestation type
var nodeResolver noderesolver.NodeResolver
for _, r := range h.c.Catalog.NodeResolvers() {
if r.Config().PluginName == attestationType {
nodeResolver = r
break
}
}

var selectors []*common.Selector
if resolved := response.Map[baseSpiffeID]; resolved != nil {
selectors = append(selectors, resolved.Entries...)
if nodeResolver == nil {
// If not matching node resolver found, skip adding additional selectors
h.c.Log.Debug("could not find node resolver type %q", attestationType)
} else {
//Call node resolver plugin to get a map of spiffeID=>Selector
response, err := nodeResolver.Resolve(ctx, &noderesolver.ResolveRequest{
BaseSpiffeIdList: []string{baseSpiffeID},
})
if err != nil {
return err
}

if resolved := response.Map[baseSpiffeID]; resolved != nil {
selectors = append(selectors, resolved.Entries...)
}
}

selectors = append(selectors, attestResponse.Selectors...)

dataStore := h.c.Catalog.DataStores()[0]
_, err = dataStore.SetNodeSelectors(ctx, &datastore.SetNodeSelectorsRequest{
_, err := dataStore.SetNodeSelectors(ctx, &datastore.SetNodeSelectorsRequest{
Selectors: &datastore.NodeSelectors{
SpiffeId: baseSpiffeID,
Selectors: selectors,
Expand Down
91 changes: 69 additions & 22 deletions pkg/server/endpoints/node/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ type HandlerTestSuite struct {
mockNodeResolver *mock_noderesolver.MockNodeResolver
server *mock_node.MockNode_FetchX509SVIDServer
now time.Time
catalog *fakeservercatalog.Catalog
}

func SetupHandlerTest(t *testing.T) *HandlerTestSuite {
Expand All @@ -74,15 +75,14 @@ func SetupHandlerTest(t *testing.T) *HandlerTestSuite {
suite.server = mock_node.NewMockNode_FetchX509SVIDServer(suite.ctrl)
suite.now = time.Now()

catalog := fakeservercatalog.New()
catalog.SetDataStores(suite.mockDataStore)
catalog.SetNodeAttestors(suite.mockNodeAttestor)
catalog.SetNodeResolvers(suite.mockNodeResolver)
suite.catalog = fakeservercatalog.New()
suite.catalog.SetDataStores(suite.mockDataStore)
suite.catalog.SetNodeAttestors(suite.mockNodeAttestor)

suite.handler = NewHandler(HandlerConfig{
Log: log,
Metrics: telemetry.Blackhole{},
Catalog: catalog,
Catalog: suite.catalog,
ServerCA: suite.mockServerCA,
TrustDomain: testTrustDomain,
})
Expand All @@ -93,9 +93,32 @@ func SetupHandlerTest(t *testing.T) *HandlerTestSuite {
return suite
}

func TestAttest(t *testing.T) {
func TestAttestWithMatchingNodeResolver(t *testing.T) {
suite := SetupHandlerTest(t)
defer suite.ctrl.Finish()
suite.catalog.AddNodeResolverNamed("fake_nodeattestor_1", suite.mockNodeResolver)

ctx := peer.NewContext(context.Background(), getFakePeer())
data := getAttestTestData()

stream := mock_node.NewMockNode_AttestServer(suite.ctrl)
stream.EXPECT().Context().Return(ctx).AnyTimes()
stream.EXPECT().Recv().Return(data.request, nil).AnyTimes()

expected := getExpectedAttest(suite, data.baseSpiffeID, data.generatedCert)
stream.EXPECT().Send(&node.AttestResponse{
SvidUpdate: expected,
}).AnyTimes()

setAttestExpectations(suite, data, true)
suite.NoError(suite.handler.Attest(stream))
suite.Equal(1, suite.limiter.callsFor(AttestMsg))
}

func TestAttestWithNonMatchingNodeResolver(t *testing.T) {
suite := SetupHandlerTest(t)
defer suite.ctrl.Finish()
suite.catalog.AddNodeResolverNamed("non_matching_resolver", suite.mockNodeResolver)

ctx := peer.NewContext(context.Background(), getFakePeer())
data := getAttestTestData()
Expand All @@ -109,21 +132,42 @@ func TestAttest(t *testing.T) {
SvidUpdate: expected,
}).AnyTimes()

setAttestExpectations(suite, data)
setAttestExpectations(suite, data, false)
suite.NoError(suite.handler.Attest(stream))
suite.Equal(1, suite.limiter.callsFor(AttestMsg))
}

func TestAttestWithEmptyNodeResolver(t *testing.T) {
suite := SetupHandlerTest(t)
defer suite.ctrl.Finish()

ctx := peer.NewContext(context.Background(), getFakePeer())
data := getAttestTestData()

stream := mock_node.NewMockNode_AttestServer(suite.ctrl)
stream.EXPECT().Context().Return(ctx).AnyTimes()
stream.EXPECT().Recv().Return(data.request, nil).AnyTimes()

expected := getExpectedAttest(suite, data.baseSpiffeID, data.generatedCert)
stream.EXPECT().Send(&node.AttestResponse{
SvidUpdate: expected,
}).AnyTimes()

setAttestExpectations(suite, data, false)
suite.NoError(suite.handler.Attest(stream))
suite.Equal(1, suite.limiter.callsFor(AttestMsg))
}
func TestAttestChallengeResponse(t *testing.T) {
suite := SetupHandlerTest(t)
defer suite.ctrl.Finish()
suite.catalog.AddNodeResolverNamed("fake_nodeattestor_1", suite.mockNodeResolver)

data := getAttestTestData()
data.challenges = []challengeResponse{
{challenge: "1+1", response: "2"},
{challenge: "5+7", response: "12"},
}
setAttestExpectations(suite, data)
setAttestExpectations(suite, data, true)

expected := getExpectedAttest(suite, data.baseSpiffeID, data.generatedCert)

Expand Down Expand Up @@ -296,7 +340,7 @@ func getAttestTestData() *fetchBaseSVIDData {
}

func setAttestExpectations(
suite *HandlerTestSuite, data *fetchBaseSVIDData) {
suite *HandlerTestSuite, data *fetchBaseSVIDData, matchingNodeResolver bool) {

stream := mock_nodeattestor.NewMockAttest_Stream(suite.ctrl)
stream.EXPECT().Send(&nodeattestor.AttestRequest{
Expand Down Expand Up @@ -342,23 +386,26 @@ func setAttestExpectations(
}}).
Return(nil, nil)

suite.mockNodeResolver.EXPECT().Resolve(gomock.Any(),
&noderesolver.ResolveRequest{
BaseSpiffeIdList: []string{data.baseSpiffeID},
}).
Return(&noderesolver.ResolveResponse{
Map: data.selectors,
}, nil)
var selectors []*common.Selector

if matchingNodeResolver {
suite.mockNodeResolver.EXPECT().Resolve(gomock.Any(),
&noderesolver.ResolveRequest{
BaseSpiffeIdList: []string{data.baseSpiffeID},
}).
Return(&noderesolver.ResolveResponse{
Map: data.selectors,
}, nil)

selectors = append(selectors, data.selector)
}
selectors = append(selectors, data.attestResponseSelectors[0], data.attestResponseSelectors[1])

suite.mockDataStore.EXPECT().SetNodeSelectors(gomock.Any(),
&datastore.SetNodeSelectorsRequest{
Selectors: &datastore.NodeSelectors{
SpiffeId: data.baseSpiffeID,
Selectors: []*common.Selector{
data.selector,
data.attestResponseSelectors[0],
data.attestResponseSelectors[1],
},
SpiffeId: data.baseSpiffeID,
Selectors: selectors,
},
}).
Return(nil, nil)
Expand Down
24 changes: 16 additions & 8 deletions test/fakes/fakeservercatalog/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,27 +41,35 @@ func (c *Catalog) DataStores() []*catalog.ManagedDataStore {
func (c *Catalog) SetNodeAttestors(nodeAttestors ...nodeattestor.NodeAttestor) {
c.nodeAttestors = nil
for i, nodeAttestor := range nodeAttestors {
c.nodeAttestors = append(c.nodeAttestors, catalog.NewManagedNodeAttestor(
nodeAttestor, common.PluginConfig{
PluginName: pluginName("nodeattestor", i),
}))
c.AddNodeAttestorNamed(pluginName("nodeattestor", i), nodeAttestor)
}
}

func (c *Catalog) AddNodeAttestorNamed(name string, nodeAttestor nodeattestor.NodeAttestor) {
c.nodeAttestors = append(c.nodeAttestors, catalog.NewManagedNodeAttestor(
nodeAttestor, common.PluginConfig{
PluginName: name,
}))
}

func (c *Catalog) NodeAttestors() []*catalog.ManagedNodeAttestor {
return c.nodeAttestors
}

func (c *Catalog) SetNodeResolvers(nodeResolvers ...noderesolver.NodeResolver) {
c.nodeResolvers = nil
for i, nodeResolver := range nodeResolvers {
c.nodeResolvers = append(c.nodeResolvers, catalog.NewManagedNodeResolver(
nodeResolver, common.PluginConfig{
PluginName: pluginName("noderesolver", i),
}))
c.AddNodeResolverNamed(pluginName("noderesolver", i), nodeResolver)
}
}

func (c *Catalog) AddNodeResolverNamed(name string, nodeResolver noderesolver.NodeResolver) {
c.nodeResolvers = append(c.nodeResolvers, catalog.NewManagedNodeResolver(
nodeResolver, common.PluginConfig{
PluginName: name,
}))
}

func (c *Catalog) NodeResolvers() []*catalog.ManagedNodeResolver {
return c.nodeResolvers
}
Expand Down

0 comments on commit d2fa6c3

Please sign in to comment.