Skip to content

Commit

Permalink
Optimize code by write a cache package
Browse files Browse the repository at this point in the history
Signed-off-by: Jianhui Zhao <[email protected]>
  • Loading branch information
Jianhui Zhao committed Apr 26, 2019
1 parent 700ea91 commit 80cee19
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 30 deletions.
81 changes: 81 additions & 0 deletions cache/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package cache

import (
"runtime"
"sync"
"time"
)

type Item struct {
value interface{}
expiration int64
}

type Cache struct {
items sync.Map
defaultExpiration time.Duration
gcInterval time.Duration
stop chan struct{}
}

// Delete all expired items from the cache.
func (c *Cache) DeleteExpired() {
now := time.Now().UnixNano()

c.items.Range(func(key, value interface{}) bool {
if value := value.(*Item); value.expiration > 0 && now > value.expiration {
c.items.Delete(key)
}
return true
})
}

func (c *Cache) gcLoop() {
ticker := time.NewTicker(c.gcInterval)
for {
select {
case <-ticker.C:
c.DeleteExpired()
case <-c.stop:
ticker.Stop()
return
}
}
}

func New(defaultExpiration, gcInterval time.Duration) *Cache {
c := &Cache{
defaultExpiration: defaultExpiration,
gcInterval: gcInterval,
stop: make(chan struct{}),
}

go c.gcLoop()

runtime.SetFinalizer(c, func(c *Cache) {
c.stop <- struct{}{}
})
return c
}

func (c *Cache) Set(key, value interface{}, d time.Duration) {
var e int64

if d == 0 {
d = c.defaultExpiration
}

if d > 0 {
e = time.Now().Add(d).UnixNano()
}

c.items.Store(key, &Item{value, e})
}

func (c *Cache) Get(key interface{}) (interface{}, bool) {
return c.items.Load(key)
}

func (c *Cache) Del(key interface{}) {
c.items.Delete(key)
}
39 changes: 9 additions & 30 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,58 +5,37 @@ import (
"fmt"
"github.com/rakyll/statik/fs"
log "github.com/sirupsen/logrus"
"github.com/zhaojh329/rttys/cache"
_ "github.com/zhaojh329/rttys/statik"
"net/http"
"os"
"strconv"
"sync"
"time"
)

type HttpSession struct {
active time.Duration
}

const MAX_SESSION_TIME = 30 * time.Minute

var httpSessions sync.Map
var httpSessions *cache.Cache

func allowOrigin(w http.ResponseWriter) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Add("Access-Control-Allow-Headers", "Content-Type")
w.Header().Set("content-type", "application/json")
}

func cleanHttpSession() {
httpSessions.Range(func(k, v interface{}) bool {
sid := k.(string)
s := v.(*HttpSession)

s.active = s.active - time.Second
if s.active == 0 {
httpSessions.Delete(sid)
}

return true
})

time.AfterFunc(5*time.Second, cleanHttpSession)
}

func httpAuth(w http.ResponseWriter, r *http.Request) bool {
c, err := r.Cookie("sid")
if err != nil {
http.Error(w, "Forbidden", http.StatusForbidden)
return false
}

s, ok := httpSessions.Load(c.Value)
if !ok {
if _, ok := httpSessions.Get(c.Value); !ok {
http.Error(w, "Forbidden", http.StatusForbidden)
return false
}

(s.(*HttpSession)).active = MAX_SESSION_TIME
// Update
httpSessions.Del(c.Value)
httpSessions.Set(c.Value, true, 0)

return true
}
Expand All @@ -78,6 +57,8 @@ func httpLogin(cfg *RttysConfig, username, password string) bool {
}

func httpStart(br *Broker, cfg *RttysConfig) {
httpSessions = cache.New(30*time.Minute, 5*time.Second)

statikFS, err := fs.New()
if err != nil {
log.Fatal(err)
Expand All @@ -100,7 +81,7 @@ func httpStart(br *Broker, cfg *RttysConfig) {

if httpLogin(cfg, username, password) {
sid := genUniqueID("http")
httpSessions.Store(sid, &HttpSession{active: MAX_SESSION_TIME})
httpSessions.Set(sid, true, 0)

http.SetCookie(w, &http.Cookie{
Name: "sid",
Expand Down Expand Up @@ -147,8 +128,6 @@ func httpStart(br *Broker, cfg *RttysConfig) {
staticfs.ServeHTTP(w, r)
})

time.AfterFunc(5*time.Second, cleanHttpSession)

if cfg.sslCert != "" && cfg.sslKey != "" {
_, err := os.Lstat(cfg.sslCert)
if err != nil {
Expand Down

0 comments on commit 80cee19

Please sign in to comment.