Skip to content

Commit

Permalink
fix 0xrawsec#102: endpoint management fully with DB engine
Browse files Browse the repository at this point in the history
  • Loading branch information
qjerome committed Jan 28, 2022
1 parent 7e23ab3 commit 79de3ba
Show file tree
Hide file tree
Showing 11 changed files with 898 additions and 616 deletions.
2 changes: 1 addition & 1 deletion .github/coverage/badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
222 changes: 109 additions & 113 deletions .github/coverage/coverage.txt

Large diffs are not rendered by default.

20 changes: 12 additions & 8 deletions api/api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -525,17 +525,21 @@ func (m *ManagerClient) FetchCommand() (*Command, error) {
return command, ErrNothingToDo
}

jsonCommand, err := ioutil.ReadAll(resp.Body)
if err != nil {
return command, fmt.Errorf("FetchCommand failed to read HTTP response body: %s", err)
}
if resp.StatusCode == http.StatusOK {
jsonCommand, err := ioutil.ReadAll(resp.Body)
if err != nil {
return command, fmt.Errorf("FetchCommand failed to read HTTP response body: %s", err)
}

// unmarshal command to be executed
if err := json.Unmarshal(jsonCommand, &command); err != nil {
return command, fmt.Errorf("FetchCommand failed to unmarshal command: %s", err)
}

// unmarshal command to be executed
if err := json.Unmarshal(jsonCommand, &command); err != nil {
return command, fmt.Errorf("FetchCommand failed to unmarshal command: %s", err)
return command, nil
}
return command, fmt.Errorf("FetchCommand unexpected HTTP status %d", resp.StatusCode)

return command, nil
}
return command, fmt.Errorf("FetchCommand failed, server cannot be authenticated")
}
Expand Down
116 changes: 12 additions & 104 deletions api/endpoint.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package api

import (
"sync"
"fmt"
"time"

"github.com/0xrawsec/sod"
Expand All @@ -27,7 +27,17 @@ type Endpoint struct {

// NewEndpoint returns a new Endpoint structure
func NewEndpoint(uuid, key string) *Endpoint {
return &Endpoint{Uuid: uuid, Key: key}
e := &Endpoint{Uuid: uuid, Key: key}
e.Initialize(e.Uuid)
return e
}

// Validate overwrite sod.Item function
func (e *Endpoint) Validate() error {
if e.Criticality < 0 || e.Criticality > 10 {
return fmt.Errorf("criticality field must be in [0;10]")
}
return nil
}

// Copy returns a pointer to a new copy of the Endpoint
Expand All @@ -40,105 +50,3 @@ func (e *Endpoint) Copy() *Endpoint {
func (e *Endpoint) UpdateLastConnection() {
e.LastConnection = time.Now().UTC()
}

// Endpoints structure used to manage endpoints
// This struct looks over complicated for what it
// does but it is because it was more complex before
// and got simplified (too lazy to change it...)
type Endpoints struct {
sync.RWMutex
endpoints []*Endpoint
mapUUID map[string]int
}

// NewEndpoints creates a new Endpoints structure
func NewEndpoints() Endpoints {
return Endpoints{
endpoints: make([]*Endpoint, 0),
mapUUID: make(map[string]int),
}
}

// Add adds an Endpoint to the Endpoints
func (es *Endpoints) Add(e *Endpoint) {
es.Lock()
defer es.Unlock()
es.endpoints = append(es.endpoints, e)
es.mapUUID[e.Uuid] = len(es.endpoints) - 1
}

// DelByUUID deletes an Endpoint by its UUID
func (es *Endpoints) DelByUUID(uuid string) {
es.Lock()
defer es.Unlock()
if i, ok := es.mapUUID[uuid]; ok {
delete(es.mapUUID, uuid)

switch {
case i == 0:
if len(es.endpoints) == 1 {
es.endpoints = make([]*Endpoint, 0)
} else {
es.endpoints = es.endpoints[i+1:]
}
case i == len(es.endpoints)-1:
es.endpoints = es.endpoints[:i]
default:
es.endpoints = append(es.endpoints[:i], es.endpoints[i+1:]...)
}
}
}

func (es *Endpoints) HasByUUID(uuid string) bool {
es.RLock()
defer es.RUnlock()
_, ok := es.mapUUID[uuid]
return ok
}

// GetByUUID returns a reference to the copy of an Endpoint by its UUID
func (es *Endpoints) GetByUUID(uuid string) (*Endpoint, bool) {
es.RLock()
defer es.RUnlock()
if i, ok := es.mapUUID[uuid]; ok {
return es.endpoints[i].Copy(), true
}
return nil, false
}

// GetMutByUUID returns reference to an Endpoint
func (es *Endpoints) GetMutByUUID(uuid string) (*Endpoint, bool) {
es.RLock()
defer es.RUnlock()
if i, ok := es.mapUUID[uuid]; ok {
return es.endpoints[i], true
}
return nil, false
}

// Len returns the number of endpoints
func (es *Endpoints) Len() int {
es.RLock()
defer es.RUnlock()
return len(es.endpoints)
}

// Endpoints returns a list of references to copies of the endpoints
func (es *Endpoints) Endpoints() []*Endpoint {
es.RLock()
defer es.RUnlock()
endpts := make([]*Endpoint, 0, len(es.endpoints))
for _, e := range es.endpoints {
endpts = append(endpts, e.Copy())
}
return endpts
}

// MutEndpoints returns a list of references of the endpoints
func (es *Endpoints) MutEndpoints() []*Endpoint {
es.RLock()
defer es.RUnlock()
endpts := make([]*Endpoint, len(es.endpoints))
copy(endpts, es.endpoints)
return endpts
}
58 changes: 42 additions & 16 deletions api/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,10 @@ type Manager struct {
detectionLogger *logger.EventLogger
detectionSearcher *logger.EventSearcher
endpointAPI *http.Server
endpoints Endpoints
adminAPI *http.Server
stop chan bool
done bool
//endpoints Endpoints
adminAPI *http.Server
stop chan bool
done bool

// Gene related members
gene struct {
Expand All @@ -261,7 +261,6 @@ type Manager struct {
// NewManager creates a new WHIDS manager with a logfile as parameter
func NewManager(c *ManagerConfig) (*Manager, error) {
var err error
var objects []sod.Object

m := Manager{Config: c, iocs: ioc.NewIocs()}
//logPath := filepath.Join(c.Logging.Root, c.Logging.LogBasename)
Expand Down Expand Up @@ -296,16 +295,6 @@ func NewManager(c *ManagerConfig) (*Manager, error) {
// initialize IoCs from db
m.iocs.FromDB(m.db)

// Endpoints initialization
m.endpoints = NewEndpoints()
if objects, err = m.db.All(&Endpoint{}); err != nil {
return nil, err
}
for _, o := range objects {
ept := o.(*Endpoint)
m.endpoints.Add(ept)
}

m.stop = make(chan bool)
if err = c.TLS.Verify(); err != nil && !c.TLS.Empty() {
return nil, err
Expand Down Expand Up @@ -405,6 +394,38 @@ func (m *Manager) updateRulesCache() {
m.gene.sha256 = hex.EncodeToString(sha256.Sum(nil))
}

// MutEndpoint returns an Endpoint pointer from database
// Result must be handled with care as any change to the Endpoint
// might be commited to the database. If an Endpoint needs to be
// modified but changes don't need to be commited, use Endpoint.Copy()
// to work on a copy
func (m *Manager) MutEndpoint(uuid string) (*Endpoint, bool) {
if o, err := m.db.GetByUUID(&Endpoint{}, uuid); err == nil {
// we return copy to endpoints not to modify cached structures
return o.(*Endpoint), true
}
return nil, false
}

// MutEndpoints returns a slice of Endpoint pointers from database
// Result must be handled with care as any change to the Endpoint
// might be commited to the database. If an Endpoint needs to be
// modified but changes don't need to be commited, use Endpoint.Copy()
// to work on a copy
func (m *Manager) MutEndpoints() (endpoints []*Endpoint, err error) {
var all []sod.Object

if all, err = m.db.All(&Endpoint{}); err != nil {
return
}
endpoints = make([]*Endpoint, 0, len(all))
for _, o := range all {
// we return copy to endpoints not to modify cached structures
endpoints = append(endpoints, o.(*Endpoint))
}
return
}

func (m *Manager) ImportRules(directory string) (err error) {
engine := engine.NewEngine()
engine.SetDumpRaw(true)
Expand Down Expand Up @@ -444,7 +465,7 @@ func (m *Manager) CreateNewAdminAPIUser(user *AdminAPIUser) (err error) {

// AddEndpoint adds new endpoint to the manager
func (m *Manager) AddEndpoint(uuid, key string) {
m.endpoints.Add(NewEndpoint(uuid, key))
m.db.InsertOrUpdate(NewEndpoint(uuid, key))
}

// UpdateReducer updates the reducer member of the Manager
Expand Down Expand Up @@ -494,6 +515,11 @@ func (m *Manager) Shutdown() (lastErr error) {
if err := m.eventLogger.Close(); err != nil {
lastErr = err
}

if err := m.db.Close(); err != nil {
lastErr = err
}

return
}

Expand Down
Loading

0 comments on commit 79de3ba

Please sign in to comment.