Skip to content

Commit

Permalink
add security domain check
Browse files Browse the repository at this point in the history
yuwnloyblog committed Oct 24, 2024
1 parent da49cc7 commit 687514b
Showing 6 changed files with 48 additions and 7 deletions.
1 change: 1 addition & 0 deletions commons/errs/errorcode.go
Original file line number Diff line number Diff line change
@@ -62,6 +62,7 @@ var (
IMErrorCode_CONNECT_CLOSE_PB_DECODE_FAIL IMErrorCode = 11017
IMErrorCode_CONNECT_CLOSE_HEARTBEAT_TIMEOUT IMErrorCode = 11018
IMErrorCode_CONNECT_CLOSE_DATA_ILLEGAL IMErrorCode = 11019
IMErrorCode_CONNECT_UNSECURITYDOMAIN IMErrorCode = 11020
)

// msg errorcode
21 changes: 17 additions & 4 deletions services/commonservices/appinfocache.go
Original file line number Diff line number Diff line change
@@ -21,8 +21,9 @@ type AppInfo struct {
AppStatus int `default:"-"`
CreatedTime time.Time `default:"-"`

EventSubConfigObj *EventSubConfigObj `default:"-"`
EventSubSwitchObj *EventSubSwitchObj `default:"-"`
EventSubConfigObj *EventSubConfigObj `default:"-"`
EventSubSwitchObj *EventSubSwitchObj `default:"-"`
SecurityDomainsObj *SecurityDomains `default:"-"`

TokenEffectiveMinute int `default:"0"`
OfflineMsgSaveTime int `default:"1440"`
@@ -44,8 +45,9 @@ type AppInfo struct {
OpenGrpSnapshot bool `default:"false"`
BigGrpThreshold int `default:"1000"`

EventSubConfig string `default:""`
EventSubSwitch string `default:""`
EventSubConfig string `default:""`
EventSubSwitch string `default:""`
SecurityDomains string `default:""`

// TestItem string
// TestInt int
@@ -98,6 +100,13 @@ func init() {
appInfo.EventSubSwitchObj = eventSubSwitch
}
}
if appInfo.SecurityDomainsObj == nil && appInfo.SecurityDomains != "" {
domains := &SecurityDomains{}
err := json.Unmarshal([]byte(appInfo.SecurityDomains), domains)
if err == nil {
appInfo.SecurityDomainsObj = domains
}
}
return appInfo
}
return notExistAppInfo
@@ -189,3 +198,7 @@ type EventSubSwitchObj struct {
OnlineSubSwitch int `json:"online_sub_switch"`
OfflineSubSwitch int `json:"offline_sub_switch"`
}

type SecurityDomains struct {
Domains []string `json:"domains"`
}
4 changes: 4 additions & 0 deletions services/connectmanager/server/imcontext/atachementutil.go
Original file line number Diff line number Diff line change
@@ -55,6 +55,10 @@ func GetDeviceId(ctx WsHandleContext) string {
return GetContextAttrString(ctx, StateKey_DeviceID)
}

func GetReferer(ctx WsHandleContext) string {
return GetContextAttrString(ctx, StateKey_Referer)
}

func GetInstanceId(ctx WsHandleContext) string {
return GetContextAttrString(ctx, StateKey_InstanceId)
}
1 change: 1 addition & 0 deletions services/connectmanager/server/imcontext/consts.go
Original file line number Diff line number Diff line change
@@ -35,6 +35,7 @@ const (
StateKey_Version string = "state.version"
StateKey_ClientIp string = "state.client_ip"
StateKey_Limiter string = "state.limiter"
StateKey_Referer string = "state.referer"
// StateKey_Extra string = "state.extra"
StateKey_InstanceId string = "state.instance_id"
)
11 changes: 8 additions & 3 deletions services/connectmanager/server/imwebsocketserver.go
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@ import (
"im-server/services/connectmanager/server/codec"
"im-server/services/connectmanager/server/imcontext"
"net/http"
"strings"
"sync"
"time"

@@ -46,7 +47,10 @@ func (server *ImWebsocketServer) ImWsServer(w http.ResponseWriter, r *http.Reque
fmt.Println("Error during connect upgrade:", err)
return
}

referer := strings.TrimSpace(r.Header.Get("Origin"))
if referer == "" {
referer = strings.TrimSpace(r.Header.Get("Referer"))
}
child := &ImWebsocketChild{
stopChan: make(chan bool, 1),
wsConn: conn,
@@ -55,7 +59,7 @@ func (server *ImWebsocketServer) ImWsServer(w http.ResponseWriter, r *http.Reque
latestActiveTime: time.Now().UnixMilli(),
}
utils.SafeGo(func() {
child.startWsListener()
child.startWsListener(referer)
})
}
func (server *ImWebsocketServer) Stop() {
@@ -71,7 +75,7 @@ type ImWebsocketChild struct {
ticker *time.Ticker
}

func (child *ImWebsocketChild) startWsListener() {
func (child *ImWebsocketChild) startWsListener(referer string) {
handler := IMWebsocketMsgHandler{child.messageListener}
ctx := &WsHandleContextImpl{
conn: child.wsConn,
@@ -83,6 +87,7 @@ func (child *ImWebsocketChild) startWsListener() {
imcontext.SetContextAttr(ctx, imcontext.StateKey_ConnectCreateTime, time.Now().UnixMilli())
imcontext.SetContextAttr(ctx, imcontext.StateKey_CtxLocker, &sync.Mutex{})
imcontext.SetContextAttr(ctx, imcontext.StateKey_Limiter, rate.NewLimiter(100, 10))
imcontext.SetContextAttr(ctx, imcontext.StateKey_Referer, referer)

//start ticker
child.startTicker(ctx, handler)
17 changes: 17 additions & 0 deletions services/connectmanager/services/logincheck.go
Original file line number Diff line number Diff line change
@@ -28,6 +28,23 @@ func CheckLogin(ctx imcontext.WsHandleContext, msg *codec.ConnectMsgBody) (int32
if _, exist := supportPlatforms[msg.Platform]; !exist {
return int32(errs.IMErrorCode_CONNECT_UNSUPPROTEDPLATFORM), ""
}
//check security domain
if msg.Platform == string(commonservices.Platform_Web) {
referer := imcontext.GetReferer(ctx)
appInfo, exist := commonservices.GetAppInfo(appkey)
if exist && appInfo != nil && appInfo.SecurityDomainsObj != nil && len(appInfo.SecurityDomainsObj.Domains) > 0 {
isContains := false
for _, domain := range appInfo.SecurityDomainsObj.Domains {
if domain == referer {
isContains = true
break
}
}
if !isContains {
return int32(errs.IMErrorCode_CONNECT_UNSECURITYDOMAIN), ""
}
}
}
//check token
tokenWrap, err := tokens.ParseTokenString(tokenStr)
if err != nil {

0 comments on commit 687514b

Please sign in to comment.