Skip to content

Commit

Permalink
Fix regression with access monitoring rule name vs type matching (gra…
Browse files Browse the repository at this point in the history
…vitational#46294)

* Fix regression with access monitoring rule name vs type matching

* Add test to ensure AMRs are not updated without correct name

* Swap to using require empty in access monitoring tests
  • Loading branch information
EdwardDowling authored Sep 9, 2024
1 parent 8d8df0e commit d73304a
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type RuleHandler struct {

apiClient teleport.Client
pluginType string
pluginName string

fetchRecipientCallback func(ctx context.Context, recipient string) (*common.Recipient, error)
}
Expand All @@ -60,6 +61,7 @@ type RuleMap struct {
type RuleHandlerConfig struct {
Client teleport.Client
PluginType string
PluginName string

// FetchRecipientCallback is a callback that maps recipient strings to plugin Recipients.
FetchRecipientCallback func(ctx context.Context, recipient string) (*common.Recipient, error)
Expand All @@ -73,6 +75,7 @@ func NewRuleHandler(conf RuleHandlerConfig) *RuleHandler {
},
apiClient: conf.Client,
pluginType: conf.PluginType,
pluginName: conf.PluginName,
fetchRecipientCallback: conf.FetchRecipientCallback,
}
}
Expand Down Expand Up @@ -161,7 +164,7 @@ func (amrh *RuleHandler) getAllAccessMonitoringRules(ctx context.Context) ([]*ac
for {
var page []*accessmonitoringrulesv1.AccessMonitoringRule
var err error
page, nextToken, err = amrh.apiClient.ListAccessMonitoringRulesWithFilter(ctx, defaultAccessMonitoringRulePageSize, nextToken, []string{types.KindAccessRequest}, amrh.pluginType)
page, nextToken, err = amrh.apiClient.ListAccessMonitoringRulesWithFilter(ctx, defaultAccessMonitoringRulePageSize, nextToken, []string{types.KindAccessRequest}, amrh.pluginName)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -187,7 +190,7 @@ func (amrh *RuleHandler) getAccessMonitoringRules() map[string]*accessmonitoring
}

func (amrh *RuleHandler) ruleApplies(amr *accessmonitoringrulesv1.AccessMonitoringRule) bool {
if amr.Spec.Notification.Name != amrh.pluginType {
if amr.Spec.Notification.Name != amrh.pluginName {
return false
}
return slices.ContainsFunc(amr.Spec.Subjects, func(subject string) bool {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,16 @@ func mockFetchRecipient(ctx context.Context, recipient string) (*common.Recipien

func TestHandleAccessMonitoringRule(t *testing.T) {
amrh := NewRuleHandler(RuleHandlerConfig{
PluginType: "fakePlugin",
PluginType: "fakePluginType",
PluginName: "fakePluginName",
FetchRecipientCallback: mockFetchRecipient,
})

rule1, err := services.NewAccessMonitoringRuleWithLabels("rule1", nil, &pb.AccessMonitoringRuleSpec{
Subjects: []string{types.KindAccessRequest},
Condition: "true",
Notification: &pb.Notification{
Name: "fakePlugin",
Name: "fakePluginName",
Recipients: []string{"a", "b"},
},
})
Expand Down Expand Up @@ -76,3 +77,46 @@ func TestHandleAccessMonitoringRule(t *testing.T) {
})
require.Empty(t, amrh.getAccessMonitoringRules())
}

func TestHandleAccessMonitoringRulePluginNameMisMatch(t *testing.T) {
amrh := NewRuleHandler(RuleHandlerConfig{
PluginName: "fakePluginName",
FetchRecipientCallback: mockFetchRecipient,
})

rule1, err := services.NewAccessMonitoringRuleWithLabels("rule1", nil, &pb.AccessMonitoringRuleSpec{
Subjects: []string{types.KindAccessRequest},
Condition: "true",
Notification: &pb.Notification{
Name: "notTheFakePluginName",
Recipients: []string{"a", "b"},
},
})
require.NoError(t, err)
amrh.HandleAccessMonitoringRule(context.Background(), types.Event{
Type: types.OpPut,
Resource: types.Resource153ToLegacy(rule1),
})
require.Empty(t, amrh.getAccessMonitoringRules())

rule2, err := services.NewAccessMonitoringRuleWithLabels("rule2", nil, &pb.AccessMonitoringRuleSpec{
Subjects: []string{types.KindAccessRequest},
Condition: "true",
Notification: &pb.Notification{
Name: "fakePluginName",
Recipients: []string{"c", "d"},
},
})
require.NoError(t, err)
amrh.HandleAccessMonitoringRule(context.Background(), types.Event{
Type: types.OpPut,
Resource: types.Resource153ToLegacy(rule2),
})
require.Len(t, amrh.getAccessMonitoringRules(), 1)

amrh.HandleAccessMonitoringRule(context.Background(), types.Event{
Type: types.OpDelete,
Resource: types.Resource153ToLegacy(rule2),
})
require.Empty(t, amrh.getAccessMonitoringRules())
}
1 change: 1 addition & 0 deletions integrations/access/accessrequest/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ func (a *App) Init(baseApp *common.BaseApp) error {
a.accessMonitoringRules = accessmonitoring.NewRuleHandler(accessmonitoring.RuleHandlerConfig{
Client: a.apiClient,
PluginType: a.pluginType,
PluginName: a.pluginName,
FetchRecipientCallback: a.bot.FetchRecipient,
})

Expand Down
1 change: 1 addition & 0 deletions integrations/access/opsgenie/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ func NewOpsgenieApp(ctx context.Context, conf *Config) (*App, error) {
opsgenieApp.accessMonitoringRules = accessmonitoring.NewRuleHandler(accessmonitoring.RuleHandlerConfig{
Client: teleClient,
PluginType: string(conf.BaseConfig.PluginType),
PluginName: pluginName,
FetchRecipientCallback: createScheduleRecipient,
})
opsgenieApp.mainJob = lib.NewServiceJob(opsgenieApp.run)
Expand Down
1 change: 1 addition & 0 deletions integrations/access/pagerduty/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func NewApp(conf Config) (*App, error) {
app.accessMonitoringRules = accessmonitoring.NewRuleHandler(accessmonitoring.RuleHandlerConfig{
Client: conf.Client,
PluginType: types.PluginTypePagerDuty,
PluginName: pluginName,
FetchRecipientCallback: func(_ context.Context, name string) (*common.Recipient, error) {
return &common.Recipient{
Name: name,
Expand Down
1 change: 1 addition & 0 deletions integrations/access/servicenow/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ func NewServiceNowApp(ctx context.Context, conf *Config) (*App, error) {
serviceNowApp.accessMonitoringRules = accessmonitoring.NewRuleHandler(accessmonitoring.RuleHandlerConfig{
Client: teleClient,
PluginType: string(conf.PluginType),
PluginName: pluginName,
FetchRecipientCallback: func(_ context.Context, name string) (*common.Recipient, error) {
return &common.Recipient{
Name: name,
Expand Down

0 comments on commit d73304a

Please sign in to comment.