Skip to content

Commit

Permalink
🐛 [Bug]: cache middleware: runtime error: index out of range [0] with…
Browse files Browse the repository at this point in the history
… length 0 (gofiber#3075)

Resolves gofiber#3072

Signed-off-by: brunodmartins <[email protected]>
  • Loading branch information
brunodmartins authored Jul 23, 2024
1 parent a57b3c0 commit f413bfe
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 55 deletions.
80 changes: 42 additions & 38 deletions middleware/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,46 +117,49 @@ func New(config ...Config) fiber.Handler {
// Get timestamp
ts := atomic.LoadUint64(&timestamp)

// Invalidate cache if requested
if cfg.CacheInvalidator != nil && cfg.CacheInvalidator(c) && e != nil {
e.exp = ts - 1
}

// Check if entry is expired
if e.exp != 0 && ts >= e.exp {
deleteKey(key)
if cfg.MaxBytes > 0 {
_, size := heap.remove(e.heapidx)
storedBytes -= size
}
} else if e.exp != 0 && !hasRequestDirective(c, noCache) {
// Separate body value to avoid msgp serialization
// We can store raw bytes with Storage 👍
if cfg.Storage != nil {
e.body = manager.getRaw(key + "_body")
}
// Set response headers from cache
c.Response().SetBodyRaw(e.body)
c.Response().SetStatusCode(e.status)
c.Response().Header.SetContentTypeBytes(e.ctype)
if len(e.cencoding) > 0 {
c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, e.cencoding)
}
for k, v := range e.headers {
c.Response().Header.SetBytesV(k, v)
// Cache Entry not found
if e != nil {
// Invalidate cache if requested
if cfg.CacheInvalidator != nil && cfg.CacheInvalidator(c) {
e.exp = ts - 1
}
// Set Cache-Control header if enabled
if cfg.CacheControl {
maxAge := strconv.FormatUint(e.exp-ts, 10)
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
}

c.Set(cfg.CacheHeader, cacheHit)

mux.Unlock()

// Return response
return nil
// Check if entry is expired
if e.exp != 0 && ts >= e.exp {
deleteKey(key)
if cfg.MaxBytes > 0 {
_, size := heap.remove(e.heapidx)
storedBytes -= size
}
} else if e.exp != 0 && !hasRequestDirective(c, noCache) {
// Separate body value to avoid msgp serialization
// We can store raw bytes with Storage 👍
if cfg.Storage != nil {
e.body = manager.getRaw(key + "_body")
}
// Set response headers from cache
c.Response().SetBodyRaw(e.body)
c.Response().SetStatusCode(e.status)
c.Response().Header.SetContentTypeBytes(e.ctype)
if len(e.cencoding) > 0 {
c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, e.cencoding)
}
for k, v := range e.headers {
c.Response().Header.SetBytesV(k, v)
}
// Set Cache-Control header if enabled
if cfg.CacheControl {
maxAge := strconv.FormatUint(e.exp-ts, 10)
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
}

c.Set(cfg.CacheHeader, cacheHit)

mux.Unlock()

// Return response
return nil
}
}

// make sure we're not blocking concurrent requests - do unlock
Expand Down Expand Up @@ -193,6 +196,7 @@ func New(config ...Config) fiber.Handler {
}
}

e = manager.acquire()
// Cache response
e.body = utils.CopyBytes(c.Response().Body())
e.status = c.Response().StatusCode()
Expand Down
134 changes: 120 additions & 14 deletions middleware/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ func Test_Cache_Expired(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{Expiration: 2 * time.Second}))

count := 0
app.Get("/", func(c fiber.Ctx) error {
return c.SendString(strconv.FormatInt(time.Now().UnixNano(), 10))
count++
return c.SendString(strconv.Itoa(count))
})

resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
Expand Down Expand Up @@ -86,9 +87,10 @@ func Test_Cache(t *testing.T) {
app := fiber.New()
app.Use(New())

count := 0
app.Get("/", func(c fiber.Ctx) error {
now := strconv.FormatInt(time.Now().UnixNano(), 10)
return c.SendString(now)
count++
return c.SendString(strconv.Itoa(count))
})

req := httptest.NewRequest(fiber.MethodGet, "/", nil)
Expand Down Expand Up @@ -305,9 +307,10 @@ func Test_Cache_Invalid_Expiration(t *testing.T) {
cache := New(Config{Expiration: 0 * time.Second})
app.Use(cache)

count := 0
app.Get("/", func(c fiber.Ctx) error {
now := strconv.FormatInt(time.Now().UnixNano(), 10)
return c.SendString(now)
count++
return c.SendString(strconv.Itoa(count))
})

req := httptest.NewRequest(fiber.MethodGet, "/", nil)
Expand Down Expand Up @@ -414,8 +417,10 @@ func Test_Cache_NothingToCache(t *testing.T) {

app.Use(New(Config{Expiration: -(time.Second * 1)}))

count := 0
app.Get("/", func(c fiber.Ctx) error {
return c.SendString(time.Now().String())
count++
return c.SendString(strconv.Itoa(count))
})

resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
Expand Down Expand Up @@ -447,12 +452,16 @@ func Test_Cache_CustomNext(t *testing.T) {
CacheControl: true,
}))

count := 0
app.Get("/", func(c fiber.Ctx) error {
return c.SendString(time.Now().String())
count++
return c.SendString(strconv.Itoa(count))
})

errorCount := 0
app.Get("/error", func(c fiber.Ctx) error {
return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String())
errorCount++
return c.Status(fiber.StatusInternalServerError).SendString(strconv.Itoa(errorCount))
})

resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
Expand Down Expand Up @@ -508,9 +517,11 @@ func Test_CustomExpiration(t *testing.T) {
return time.Second * time.Duration(newCacheTime)
}}))

count := 0
app.Get("/", func(c fiber.Ctx) error {
count++
c.Response().Header.Add("Cache-Time", "1")
return c.SendString(strconv.FormatInt(time.Now().UnixNano(), 10))
return c.SendString(strconv.Itoa(count))
})

resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
Expand Down Expand Up @@ -588,8 +599,11 @@ func Test_CacheHeader(t *testing.T) {
return c.SendString(fiber.Query[string](c, "cache"))
})

count := 0
app.Get("/error", func(c fiber.Ctx) error {
return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String())
count++
c.Response().Header.Add("Cache-Time", "1")
return c.Status(fiber.StatusInternalServerError).SendString(strconv.Itoa(count))
})

resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
Expand All @@ -615,10 +629,13 @@ func Test_Cache_WithHead(t *testing.T) {
app := fiber.New()
app.Use(New())

count := 0
handler := func(c fiber.Ctx) error {
now := strconv.FormatInt(time.Now().UnixNano(), 10)
return c.SendString(now)
count++
c.Response().Header.Add("Cache-Time", "1")
return c.SendString(strconv.Itoa(count))
}

app.Route("/").Get(handler).Head(handler)

req := httptest.NewRequest(fiber.MethodHead, "/", nil)
Expand Down Expand Up @@ -708,8 +725,10 @@ func Test_CacheInvalidation(t *testing.T) {
},
}))

count := 0
app.Get("/", func(c fiber.Ctx) error {
return c.SendString(time.Now().String())
count++
return c.SendString(strconv.Itoa(count))
})

resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
Expand All @@ -731,6 +750,93 @@ func Test_CacheInvalidation(t *testing.T) {
require.NotEqual(t, body, bodyInvalidate)
}

func Test_CacheInvalidation_noCacheEntry(t *testing.T) {
t.Parallel()
t.Run("Cache Invalidator should not be called if no cache entry exist ", func(t *testing.T) {
t.Parallel()
app := fiber.New()
cacheInvalidatorExecuted := false
app.Use(New(Config{
CacheControl: true,
CacheInvalidator: func(c fiber.Ctx) bool {
cacheInvalidatorExecuted = true
return fiber.Query[bool](c, "invalidate")
},
MaxBytes: 10 * 1024 * 1024,
}))
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?invalidate=true", nil))
require.NoError(t, err)
require.False(t, cacheInvalidatorExecuted)
})
}

func Test_CacheInvalidation_removeFromHeap(t *testing.T) {
t.Parallel()
t.Run("Invalidate and remove from the heap", func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
CacheControl: true,
CacheInvalidator: func(c fiber.Ctx) bool {
return fiber.Query[bool](c, "invalidate")
},
MaxBytes: 10 * 1024 * 1024,
}))

count := 0
app.Get("/", func(c fiber.Ctx) error {
count++
return c.SendString(strconv.Itoa(count))
})

resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)

respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
bodyCached, err := io.ReadAll(respCached.Body)
require.NoError(t, err)
require.True(t, bytes.Equal(body, bodyCached))
require.NotEmpty(t, respCached.Header.Get(fiber.HeaderCacheControl))

respInvalidate, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?invalidate=true", nil))
require.NoError(t, err)
bodyInvalidate, err := io.ReadAll(respInvalidate.Body)
require.NoError(t, err)
require.NotEqual(t, body, bodyInvalidate)
})
}

func Test_CacheStorage_CustomHeaders(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
CacheControl: true,
Storage: memory.New(),
MaxBytes: 10 * 1024 * 1024,
}))

app.Get("/", func(c fiber.Ctx) error {
c.Response().Header.Set("Content-Type", "text/xml")
c.Response().Header.Set("Content-Encoding", "utf8")
return c.Send([]byte("<xml><value>Test</value></xml>"))
})

resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)

respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
bodyCached, err := io.ReadAll(respCached.Body)
require.NoError(t, err)
require.True(t, bytes.Equal(body, bodyCached))
require.NotEmpty(t, respCached.Header.Get(fiber.HeaderCacheControl))
}

// Because time points are updated once every X milliseconds, entries in tests can often have
// equal expiration times and thus be in an random order. This closure hands out increasing
// time intervals to maintain strong ascending order of expiration
Expand Down
2 changes: 1 addition & 1 deletion middleware/cache/heap.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type heapEntry struct {
// elements in constant time. It does so by handing out special indices
// and tracking entry movement.
//
// indexdedHeap is used for quickly finding entries with the lowest
// indexedHeap is used for quickly finding entries with the lowest
// expiration timestamp and deleting arbitrary entries.
type indexedHeap struct {
// Slice the heap is built on
Expand Down
3 changes: 1 addition & 2 deletions middleware/cache/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ func (m *manager) get(key string) *item {
return it
}
if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
it = m.acquire()
return it
return nil
}
return it
}
Expand Down
26 changes: 26 additions & 0 deletions middleware/cache/manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package cache

import (
"testing"
"time"

"github.com/gofiber/utils/v2"
"github.com/stretchr/testify/assert"
)

func Test_manager_get(t *testing.T) {
t.Parallel()
cacheManager := newManager(nil)
t.Run("Item not found in cache", func(t *testing.T) {
t.Parallel()
assert.Nil(t, cacheManager.get(utils.UUID()))
})
t.Run("Item found in cache", func(t *testing.T) {
t.Parallel()
id := utils.UUID()
cacheItem := cacheManager.acquire()
cacheItem.body = []byte("test-body")
cacheManager.set(id, cacheItem, 10*time.Second)
assert.NotNil(t, cacheManager.get(id))
})
}

0 comments on commit f413bfe

Please sign in to comment.