diff --git a/integrations/access/accessmonitoring/access_monitoring_rules.go b/integrations/access/accessmonitoring/access_monitoring_rules.go new file mode 100644 index 0000000000000..2d6c1a72005c2 --- /dev/null +++ b/integrations/access/accessmonitoring/access_monitoring_rules.go @@ -0,0 +1,196 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package accessmonitoring + +import ( + "context" + "maps" + "slices" + "sync" + + "github.com/gravitational/trace" + + accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/integrations/access/common" + "github.com/gravitational/teleport/integrations/access/common/teleport" + "github.com/gravitational/teleport/integrations/lib/logger" +) + +const ( + // defaultAccessMonitoringRulePageSize is the default number of rules to retrieve per request + defaultAccessMonitoringRulePageSize = 1000 +) + +// RuleHandler stores a cache of Access Monitoring Rules for use with Access Request routing in plugins. +// Must be initialized by calling InitAccessMonitoringRulesCache, a watcher on Acccess Monitoring Rules must pass in new rules using HandleAccessMonitoringRule. +type RuleHandler struct { + accessMonitoringRules RuleMap + + apiClient teleport.Client + pluginType string + + fetchRecipientCallback func(ctx context.Context, recipient string) (*common.Recipient, error) +} + +// RuleMap is a concurrent map for access monitoring rules. +type RuleMap struct { + sync.RWMutex + // rules are the access monitoring rules being stored. + rules map[string]*accessmonitoringrulesv1.AccessMonitoringRule +} + +// RuleHandlerConfig stores the configuration for RuleHandler +type RuleHandlerConfig struct { + Client teleport.Client + PluginType string + + // FetchRecipientCallback is a callback that maps recipient strings to plugin Recipients. + FetchRecipientCallback func(ctx context.Context, recipient string) (*common.Recipient, error) +} + +// NewRuleHandler returns a new RuleHandler. +func NewRuleHandler(conf RuleHandlerConfig) *RuleHandler { + return &RuleHandler{ + accessMonitoringRules: RuleMap{ + rules: make(map[string]*accessmonitoringrulesv1.AccessMonitoringRule), + }, + apiClient: conf.Client, + pluginType: conf.PluginType, + fetchRecipientCallback: conf.FetchRecipientCallback, + } +} + +// InitAccessMonitoringRulesCache initializes the cache of Access Monitoring Rules. +func (amrh *RuleHandler) InitAccessMonitoringRulesCache(ctx context.Context) error { + accessMonitoringRules, err := amrh.getAllAccessMonitoringRules(ctx) + if err != nil { + return trace.Wrap(err) + } + amrh.accessMonitoringRules.Lock() + defer amrh.accessMonitoringRules.Unlock() + for _, amr := range accessMonitoringRules { + if !amrh.ruleApplies(amr) { + continue + } + amrh.accessMonitoringRules.rules[amr.GetMetadata().Name] = amr + } + return nil +} + +// HandleAccessMonitoringRule checks if a new rule should be stored in the cache and updates accordingly. +// Also removes deleted rules from the cache. +func (amrh *RuleHandler) HandleAccessMonitoringRule(ctx context.Context, event types.Event) error { + if kind := event.Resource.GetKind(); kind != types.KindAccessMonitoringRule { + return trace.BadParameter("expected %s resource kind, got %s", types.KindAccessMonitoringRule, kind) + } + + amrh.accessMonitoringRules.Lock() + defer amrh.accessMonitoringRules.Unlock() + switch op := event.Type; op { + case types.OpPut: + e, ok := event.Resource.(types.Resource153Unwrapper) + if !ok { + return trace.BadParameter("expected Resource153Unwrapper resource type, got %T", event.Resource) + } + req, ok := e.Unwrap().(*accessmonitoringrulesv1.AccessMonitoringRule) + if !ok { + return trace.BadParameter("expected AccessMonitoringRule resource type, got %T", event.Resource) + } + + // In the event an existing rule no longer applies we must remove it. + if !amrh.ruleApplies(req) { + delete(amrh.accessMonitoringRules.rules, event.Resource.GetName()) + return nil + } + amrh.accessMonitoringRules.rules[req.Metadata.Name] = req + return nil + case types.OpDelete: + delete(amrh.accessMonitoringRules.rules, event.Resource.GetName()) + return nil + default: + return trace.BadParameter("unexpected event operation %s", op) + } +} + +// RecipientsFromAccessMonitoringRules returns the recipients that result from the Access Monitoring Rules being applied to the given Access Request. +func (amrh *RuleHandler) RecipientsFromAccessMonitoringRules(ctx context.Context, req types.AccessRequest) *common.RecipientSet { + log := logger.Get(ctx) + recipientSet := common.NewRecipientSet() + + for _, rule := range amrh.getAccessMonitoringRules() { + match, err := MatchAccessRequest(rule.Spec.Condition, req) + if err != nil { + log.WithError(err).WithField("rule", rule.Metadata.Name). + Warn("Failed to parse access monitoring notification rule") + } + if !match { + continue + } + for _, recipient := range rule.Spec.Notification.Recipients { + rec, err := amrh.fetchRecipientCallback(ctx, recipient) + if err != nil { + log.WithError(err).Warn("Failed to fetch plugin recipients based on Access moniotring rule recipients") + continue + } + recipientSet.Add(*rec) + } + } + return &recipientSet +} + +func (amrh *RuleHandler) getAllAccessMonitoringRules(ctx context.Context) ([]*accessmonitoringrulesv1.AccessMonitoringRule, error) { + var resources []*accessmonitoringrulesv1.AccessMonitoringRule + var nextToken string + for { + var page []*accessmonitoringrulesv1.AccessMonitoringRule + var err error + page, nextToken, err = amrh.apiClient.ListAccessMonitoringRulesWithFilter(ctx, defaultAccessMonitoringRulePageSize, nextToken, []string{types.KindAccessRequest}, amrh.pluginType) + if err != nil { + return nil, trace.Wrap(err) + } + + for _, amr := range page { + if !amrh.ruleApplies(amr) { + continue + } + resources = append(resources, amr) + } + + if nextToken == "" { + break + } + } + return resources, nil +} + +func (amrh *RuleHandler) getAccessMonitoringRules() map[string]*accessmonitoringrulesv1.AccessMonitoringRule { + amrh.accessMonitoringRules.RLock() + defer amrh.accessMonitoringRules.RUnlock() + return maps.Clone(amrh.accessMonitoringRules.rules) +} + +func (amrh *RuleHandler) ruleApplies(amr *accessmonitoringrulesv1.AccessMonitoringRule) bool { + if amr.Spec.Notification.Name != amrh.pluginType { + return false + } + return slices.ContainsFunc(amr.Spec.Subjects, func(subject string) bool { + return subject == types.KindAccessRequest + }) +} diff --git a/integrations/access/accessmonitoring/access_monitoring_rules_test.go b/integrations/access/accessmonitoring/access_monitoring_rules_test.go new file mode 100644 index 0000000000000..9069cf08e536c --- /dev/null +++ b/integrations/access/accessmonitoring/access_monitoring_rules_test.go @@ -0,0 +1,78 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package accessmonitoring + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/integrations/access/common" + "github.com/gravitational/teleport/lib/services" +) + +func mockFetchRecipient(ctx context.Context, recipient string) (*common.Recipient, error) { + return nil, nil +} + +func TestHandleAccessMonitoringRule(t *testing.T) { + amrh := NewRuleHandler(RuleHandlerConfig{ + PluginType: "fakePlugin", + FetchRecipientCallback: mockFetchRecipient, + }) + + rule1, err := services.NewAccessMonitoringRuleWithLabels("rule1", nil, &pb.AccessMonitoringRuleSpec{ + Subjects: []string{types.KindAccessRequest}, + Condition: "true", + Notification: &pb.Notification{ + Name: "fakePlugin", + Recipients: []string{"a", "b"}, + }, + }) + require.NoError(t, err) + amrh.HandleAccessMonitoringRule(context.Background(), types.Event{ + Type: types.OpPut, + Resource: types.Resource153ToLegacy(rule1), + }) + require.Len(t, amrh.getAccessMonitoringRules(), 1) + + rule2, err := services.NewAccessMonitoringRuleWithLabels("rule2", nil, &pb.AccessMonitoringRuleSpec{ + Subjects: []string{types.KindAccessRequest}, + Condition: "true", + Notification: &pb.Notification{ + Name: "aDifferentFakePlugin", + Recipients: []string{"a", "b"}, + }, + }) + require.NoError(t, err) + amrh.HandleAccessMonitoringRule(context.Background(), types.Event{ + Type: types.OpPut, + Resource: types.Resource153ToLegacy(rule2), + }) + require.Len(t, amrh.getAccessMonitoringRules(), 1) + + amrh.HandleAccessMonitoringRule(context.Background(), types.Event{ + Type: types.OpDelete, + Resource: types.Resource153ToLegacy(rule1), + }) + require.Empty(t, amrh.getAccessMonitoringRules()) +} diff --git a/integrations/access/accessrequest/request_mapping.go b/integrations/access/accessmonitoring/request_mapping.go similarity index 97% rename from integrations/access/accessrequest/request_mapping.go rename to integrations/access/accessmonitoring/request_mapping.go index 7bc46121d0c0e..fb50831d14e45 100644 --- a/integrations/access/accessrequest/request_mapping.go +++ b/integrations/access/accessmonitoring/request_mapping.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package accessrequest +package accessmonitoring import ( "time" @@ -88,7 +88,7 @@ func newRequestConditionParser() (*typical.Parser[accessRequestExpressionEnv, an return requestConditionParser, nil } -func matchAccessRequest(expr string, req types.AccessRequest) (bool, error) { +func MatchAccessRequest(expr string, req types.AccessRequest) (bool, error) { parsedExpr, err := parseAccessRequestExpression(expr) if err != nil { return false, trace.Wrap(err) diff --git a/integrations/access/accessrequest/app.go b/integrations/access/accessrequest/app.go index 75d11a3b52530..9d32482824fba 100644 --- a/integrations/access/accessrequest/app.go +++ b/integrations/access/accessrequest/app.go @@ -21,16 +21,14 @@ package accessrequest import ( "context" "fmt" - "maps" "slices" - "sync" "time" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/accessrequest" - accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/integrations/access/accessmonitoring" "github.com/gravitational/teleport/integrations/access/common" "github.com/gravitational/teleport/integrations/access/common/teleport" "github.com/gravitational/teleport/integrations/lib" @@ -42,8 +40,6 @@ import ( const ( // handlerTimeout is used to bound the execution time of watcher event handler. handlerTimeout = time.Second * 5 - // defaultAccessMonitoringRulePageSize is the default number of rules to retrieve per request. - defaultAccessMonitoringRulePageSize = 10 ) // App is the access request application for plugins. This will notify when access requests @@ -57,19 +53,12 @@ type App struct { bot MessagingBot job lib.ServiceJob - accessMonitoringRules amrMap -} - -type amrMap struct { - sync.RWMutex - rules map[string]*accessmonitoringrulesv1.AccessMonitoringRule + accessMonitoringRules *accessmonitoring.RuleHandler } // NewApp will create a new access request application. func NewApp(bot MessagingBot) common.App { - app := &App{accessMonitoringRules: amrMap{ - rules: make(map[string]*accessmonitoringrulesv1.AccessMonitoringRule), - }} + app := &App{} app.job = lib.NewServiceJob(app.run) return app } @@ -93,6 +82,12 @@ func (a *App) Init(baseApp *common.BaseApp) error { return trace.BadParameter("bot does not implement access request bot methods") } + a.accessMonitoringRules = accessmonitoring.NewRuleHandler(accessmonitoring.RuleHandlerConfig{ + Client: a.apiClient, + PluginType: a.pluginType, + FetchRecipientCallback: a.bot.FetchRecipient, + }) + return nil } @@ -159,7 +154,7 @@ func (a *App) run(ctx context.Context) error { // Check if KindAccessMonitoringRule resources are being watched, // the role the plugin is running as may not have access. if slices.Contains(acceptedWatchKinds, types.KindAccessMonitoringRule) { - if err := a.initAccessMonitoringRulesCache(ctx); err != nil { + if err := a.accessMonitoringRules.InitAccessMonitoringRulesCache(ctx); err != nil { return trace.Wrap(err, "initializing Access Monitoring Rule cache") } } @@ -173,21 +168,12 @@ func (a *App) run(ctx context.Context) error { return nil } -func (a *App) amrAppliesToThisPlugin(amr *accessmonitoringrulesv1.AccessMonitoringRule) bool { - if amr.Spec.Notification.Name != a.pluginName { - return false - } - return slices.ContainsFunc(amr.Spec.Subjects, func(subject string) bool { - return subject == types.KindAccessRequest - }) -} - // onWatcherEvent is called for every cluster Event. It will filter out non-access-request events and // call onPendingRequest, onResolvedRequest and on DeletedRequest depending on the event. func (a *App) onWatcherEvent(ctx context.Context, event types.Event) error { switch event.Resource.GetKind() { case types.KindAccessMonitoringRule: - return trace.Wrap(a.handleAccessMonitoringRule(ctx, event)) + return trace.Wrap(a.accessMonitoringRules.HandleAccessMonitoringRule(ctx, event)) case types.KindAccessRequest: return trace.Wrap(a.handleAcessRequest(ctx, event)) } @@ -238,39 +224,6 @@ func (a *App) handleAcessRequest(ctx context.Context, event types.Event) error { } } -func (a *App) handleAccessMonitoringRule(ctx context.Context, event types.Event) error { - if kind := event.Resource.GetKind(); kind != types.KindAccessMonitoringRule { - return trace.BadParameter("expected %s resource kind, got %s", types.KindAccessMonitoringRule, kind) - } - - a.accessMonitoringRules.Lock() - defer a.accessMonitoringRules.Unlock() - switch op := event.Type; op { - case types.OpPut: - e, ok := event.Resource.(types.Resource153Unwrapper) - if !ok { - return trace.BadParameter("expected Resource153Unwrapper resource type, got %T", event.Resource) - } - req, ok := e.Unwrap().(*accessmonitoringrulesv1.AccessMonitoringRule) - if !ok { - return trace.BadParameter("expected AccessMonitoringRule resource type, got %T", event.Resource) - } - - // In the event an existing rule no longer applies we must remove it. - if !a.amrAppliesToThisPlugin(req) { - delete(a.accessMonitoringRules.rules, event.Resource.GetName()) - return nil - } - a.accessMonitoringRules.rules[req.Metadata.Name] = req - return nil - case types.OpDelete: - delete(a.accessMonitoringRules.rules, event.Resource.GetName()) - return nil - default: - return trace.BadParameter("unexpected event operation %s", op) - } -} - func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) error { log := logger.Get(ctx) @@ -431,7 +384,7 @@ func (a *App) getMessageRecipients(ctx context.Context, req types.AccessRequest) // This can happen if this set contains the channel `C` and the email for channel `C`. recipientSet := common.NewRecipientSet() - recipients := a.recipientsFromAccessMonitoringRules(ctx, req) + recipients := a.accessMonitoringRules.RecipientsFromAccessMonitoringRules(ctx, req) recipients.ForEach(func(r common.Recipient) { recipientSet.Add(r) }) @@ -483,40 +436,6 @@ func (a *App) getMessageRecipients(ctx context.Context, req types.AccessRequest) return recipientSet.ToSlice() } -func (a *App) recipientsFromAccessMonitoringRules(ctx context.Context, req types.AccessRequest) *common.RecipientSet { - log := logger.Get(ctx) - recipientSet := common.NewRecipientSet() - - // This switch is used to determine which plugins we are enabling access monitoring notification rules for. - switch a.pluginType { - // Enabled plugins are added to this case. - case types.PluginTypeSlack, types.PluginTypeMattermost: - log.Debug("Applying access monitoring rules to request") - default: - return &recipientSet - } - - for _, rule := range a.getAccessMonitoringRules() { - match, err := matchAccessRequest(rule.Spec.Condition, req) - if err != nil { - log.WithError(err).WithField("rule", rule.Metadata.Name). - Warn("Failed to parse access monitoring notification rule") - } - if !match { - continue - } - for _, recipient := range rule.Spec.Notification.Recipients { - rec, err := a.bot.FetchRecipient(ctx, recipient) - if err != nil { - log.WithError(err).Warn("Failed to fetch plugin recipients based on Access moniotring rule recipients") - continue - } - recipientSet.Add(*rec) - } - } - return &recipientSet -} - // updateMessages updates the messages status and adds the resolve reason. func (a *App) updateMessages(ctx context.Context, reqID string, tag pd.ResolutionTag, reason string, reviews []types.AccessReview) error { log := logger.Get(ctx) @@ -589,45 +508,3 @@ func (a *App) getResourceNames(ctx context.Context, req types.AccessRequest) ([] } return resourceNames, nil } - -func (a *App) initAccessMonitoringRulesCache(ctx context.Context) error { - accessMonitoringRules, err := a.getAllAccessMonitoringRules(ctx) - if err != nil { - return trace.Wrap(err) - } - a.accessMonitoringRules.Lock() - defer a.accessMonitoringRules.Unlock() - for _, amr := range accessMonitoringRules { - if !a.amrAppliesToThisPlugin(amr) { - continue - } - a.accessMonitoringRules.rules[amr.GetMetadata().Name] = amr - } - return nil -} - -func (a *App) getAllAccessMonitoringRules(ctx context.Context) ([]*accessmonitoringrulesv1.AccessMonitoringRule, error) { - var resources []*accessmonitoringrulesv1.AccessMonitoringRule - var nextToken string - for { - var page []*accessmonitoringrulesv1.AccessMonitoringRule - var err error - page, nextToken, err = a.apiClient.ListAccessMonitoringRulesWithFilter(ctx, defaultAccessMonitoringRulePageSize, nextToken, []string{types.KindAccessRequest}, a.pluginName) - if err != nil { - return nil, trace.Wrap(err) - } - - resources = append(resources, page...) - - if nextToken == "" { - break - } - } - return resources, nil -} - -func (a *App) getAccessMonitoringRules() map[string]*accessmonitoringrulesv1.AccessMonitoringRule { - a.accessMonitoringRules.RLock() - defer a.accessMonitoringRules.RUnlock() - return maps.Clone(a.accessMonitoringRules.rules) -} diff --git a/integrations/access/common/recipient.go b/integrations/access/common/recipient.go index 67273cefa4b5b..0e2a7a2d21409 100644 --- a/integrations/access/common/recipient.go +++ b/integrations/access/common/recipient.go @@ -28,6 +28,8 @@ import ( const ( // RecipientKindSchedule shows a recipient is a schedule. RecipientKindSchedule = "schedule" + // RecipientKindTeam shows a recipient is a team. + RecipientKindTeam = "team" ) // RawRecipientsMap is a mapping of roles to recipient(s). diff --git a/integrations/access/opsgenie/app.go b/integrations/access/opsgenie/app.go index 74b13469ac351..8ab196ee2a254 100644 --- a/integrations/access/opsgenie/app.go +++ b/integrations/access/opsgenie/app.go @@ -31,6 +31,7 @@ import ( tp "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/integrations/access/accessmonitoring" "github.com/gravitational/teleport/integrations/access/common" "github.com/gravitational/teleport/integrations/access/common/teleport" "github.com/gravitational/teleport/integrations/lib" @@ -67,6 +68,8 @@ type App struct { opsgenie *Client mainJob lib.ServiceJob conf Config + + accessMonitoringRules *accessmonitoring.RuleHandler } // NewOpsgenieApp initializes a new teleport-opsgenie app and returns it. @@ -75,6 +78,15 @@ func NewOpsgenieApp(ctx context.Context, conf *Config) (*App, error) { PluginName: pluginName, conf: *conf, } + teleClient, err := conf.GetTeleportClient(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + opsgenieApp.accessMonitoringRules = accessmonitoring.NewRuleHandler(accessmonitoring.RuleHandlerConfig{ + Client: teleClient, + PluginType: string(conf.BaseConfig.PluginType), + FetchRecipientCallback: createScheduleRecipient, + }) opsgenieApp.mainJob = lib.NewServiceJob(opsgenieApp.run) return opsgenieApp, nil } @@ -111,7 +123,10 @@ func (a *App) run(ctx context.Context) error { watcherJob, err := watcherjob.NewJob( a.teleport, watcherjob.Config{ - Watch: types.Watch{Kinds: []types.WatchKind{types.WatchKind{Kind: types.KindAccessRequest}}}, + Watch: types.Watch{Kinds: []types.WatchKind{ + {Kind: types.KindAccessRequest}, + {Kind: types.KindAccessMonitoringRule}, + }}, EventFuncTimeout: handlerTimeout, }, a.onWatcherEvent, @@ -125,6 +140,10 @@ func (a *App) run(ctx context.Context) error { return trace.Wrap(err) } + if err := a.accessMonitoringRules.InitAccessMonitoringRulesCache(ctx); err != nil { + return trace.Wrap(err) + } + a.mainJob.SetReady(ok) if ok { log.Info("Plugin is ready") @@ -181,7 +200,19 @@ func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, err return pong, trace.Wrap(err) } +// onWatcherEvent is called for every cluster Event. It will call the handlers +// for access request and access monitoring rule events. func (a *App) onWatcherEvent(ctx context.Context, event types.Event) error { + switch event.Resource.GetKind() { + case types.KindAccessMonitoringRule: + return trace.Wrap(a.accessMonitoringRules.HandleAccessMonitoringRule(ctx, event)) + case types.KindAccessRequest: + return trace.Wrap(a.handleAcessRequest(ctx, event)) + } + return trace.BadParameter("unexpected kind %s", event.Resource.GetKind()) +} + +func (a *App) handleAcessRequest(ctx context.Context, event types.Event) error { if kind := event.Resource.GetKind(); kind != types.KindAccessRequest { return trace.Errorf("unexpected kind %s", kind) } @@ -229,11 +260,6 @@ func (a *App) onWatcherEvent(ctx context.Context, event types.Event) error { } func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) error { - if len(req.GetSystemAnnotations()) == 0 { - logger.Get(ctx).Debug("Cannot proceed further. Request is missing any annotations") - return nil - } - // First, try to create a notification alert. isNew, notifyErr := a.tryNotifyService(ctx, req) @@ -275,34 +301,28 @@ func (a *App) onDeletedRequest(ctx context.Context, reqID string) error { return a.resolveAlert(ctx, reqID, Resolution{Tag: ResolvedExpired}) } -// Get services to notify from both annotations: /notify-services and /teams -// Return error if both are empty -func (a *App) getNotifyServiceNames(ctx context.Context, req types.AccessRequest) ([]string, error) { +// getNotifySchedulesAndTeams get schedules and teams to notify from both +// annotations: /notify-services and /teams, returns an error if both are empty. +func (a *App) getNotifySchedulesAndTeams(ctx context.Context, req types.AccessRequest) (schedules []string, teams []string, err error) { log := logger.Get(ctx) - var servicesNames []string - scheduleAnnotationKey := types.TeleportNamespace + types.ReqAnnotationNotifySchedulesLabel - schedules, err := common.GetServiceNamesFromAnnotations(req, scheduleAnnotationKey) + schedules, err = common.GetServiceNamesFromAnnotations(req, scheduleAnnotationKey) if err != nil { log.Debugf("No schedules to notifiy in %s", scheduleAnnotationKey) - } else { - servicesNames = append(servicesNames, schedules...) } teamAnnotationKey := types.TeleportNamespace + types.ReqAnnotationTeamsLabel - teams, err := common.GetServiceNamesFromAnnotations(req, teamAnnotationKey) + teams, err = common.GetServiceNamesFromAnnotations(req, teamAnnotationKey) if err != nil { log.Debugf("No teams to notifiy in %s", teamAnnotationKey) - } else { - servicesNames = append(servicesNames, teams...) } - if len(servicesNames) == 0 { - return nil, trace.NotFound("no services to notify") + if len(schedules) == 0 && len(teams) == 0 { + return nil, nil, trace.NotFound("no schedules or teams to notify") } - return servicesNames, nil + return schedules, teams, nil } func (a *App) getOnCallServiceNames(req types.AccessRequest) ([]string, error) { @@ -313,8 +333,8 @@ func (a *App) getOnCallServiceNames(req types.AccessRequest) ([]string, error) { func (a *App) tryNotifyService(ctx context.Context, req types.AccessRequest) (bool, error) { log := logger.Get(ctx) - serviceNames, err := a.getNotifyServiceNames(ctx, req) - if err != nil || len(serviceNames) == 0 { + recipientSchedules, recipientTeams, err := a.getMessageRecipients(ctx, req) + if err != nil { log.Debugf("Skipping the notification: %s", err) return false, trace.Wrap(errMissingAnnotation) } @@ -324,6 +344,22 @@ func (a *App) tryNotifyService(ctx context.Context, req types.AccessRequest) (bo for k, v := range req.GetSystemAnnotations() { annotations[k] = v } + + if len(recipientTeams) != 0 { + teams := make([]string, 0, len(recipientTeams)) + for _, t := range recipientTeams { + teams = append(teams, t.Name) + } + annotations[types.TeleportNamespace+types.ReqAnnotationTeamsLabel] = teams + } + if len(recipientSchedules) != 0 { + schedules := make([]string, 0, len(recipientSchedules)) + for _, s := range recipientSchedules { + schedules = append(schedules, s.Name) + } + annotations[types.TeleportNamespace+types.ReqAnnotationNotifySchedulesLabel] = schedules + } + reqData := RequestData{ User: req.GetUser(), Roles: req.GetRoles(), @@ -357,6 +393,40 @@ func (a *App) tryNotifyService(ctx context.Context, req types.AccessRequest) (bo return isNew, nil } +func (a *App) getMessageRecipients(ctx context.Context, req types.AccessRequest) ([]common.Recipient, []common.Recipient, error) { + recipientSetSchedules := common.NewRecipientSet() + recipientSchedules := a.accessMonitoringRules.RecipientsFromAccessMonitoringRules(ctx, req) + recipientSchedules.ForEach(func(r common.Recipient) { + recipientSetSchedules.Add(r) + }) + // Access Monitoring Rules recipients does not have a way to handle separate recipient types currently. + // Recipients from Access Monitoring Rules will be schedules only currently. + if recipientSetSchedules.Len() != 0 { + return recipientSetSchedules.ToSlice(), nil, nil + } + rawSchedules, rawTeams, err := a.getNotifySchedulesAndTeams(ctx, req) + if err != nil { + return nil, nil, trace.Wrap(err) + } + for _, rawSchedule := range rawSchedules { + recipientSetSchedules.Add(common.Recipient{ + Name: rawSchedule, + ID: rawSchedule, + Kind: common.RecipientKindSchedule, + }) + } + + recipientSetTeams := common.NewRecipientSet() + for _, rawTeam := range rawTeams { + recipientSetTeams.Add(common.Recipient{ + Name: rawTeam, + ID: rawTeam, + Kind: common.RecipientKindTeam, + }) + } + return recipientSetSchedules.ToSlice(), nil, nil +} + // createAlert posts an alert with request information. func (a *App) createAlert(ctx context.Context, reqID string, reqData RequestData) error { data, err := a.opsgenie.CreateAlert(ctx, reqID, reqData) diff --git a/integrations/access/opsgenie/bot.go b/integrations/access/opsgenie/bot.go index 6eb7a474ad3e5..745d34fb6c549 100644 --- a/integrations/access/opsgenie/bot.go +++ b/integrations/access/opsgenie/bot.go @@ -129,6 +129,10 @@ func (b *Bot) UpdateMessages(ctx context.Context, reqID string, data pd.AccessRe // FetchRecipient returns the recipient for the given raw recipient. func (b *Bot) FetchRecipient(ctx context.Context, name string) (*common.Recipient, error) { + return createScheduleRecipient(ctx, name) +} + +func createScheduleRecipient(ctx context.Context, name string) (*common.Recipient, error) { return &common.Recipient{ Name: name, ID: name,