Skip to content

Commit

Permalink
Add Loader interface and Get tests
Browse files Browse the repository at this point in the history
  • Loading branch information
swithek committed Feb 5, 2022
1 parent 2c1aa16 commit a84d130
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 66 deletions.
103 changes: 39 additions & 64 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ type Cache[K comparable, V any] struct {
insertFns []func(*Item[K, V])
}

loader Loader

loader Loader[K, V]
stopCh chan struct{}

capacity uint64
Expand All @@ -55,8 +54,7 @@ type Cache[K comparable, V any] struct {
// New creates a new instance of cache.
func New[K comparable, V any]() *Cache[K, V] {
c := &Cache[K, V]{
loaderGroup: &singleflight.Group{},
stopCh: make(chan struct{}),
stopCh: make(chan struct{}),
}
c.items.values = make(map[K]*list.Element)
c.items.lru = list.New()
Expand Down Expand Up @@ -256,37 +254,23 @@ func (c *Cache[K, V]) Get(key K) *Item[K, V] {
elem := c.get(key, true)
c.items.mu.Unlock()

if elem != nil {
if elem == nil {
c.metricsMu.Lock()
c.metrics.Hits++
c.metrics.Misses++
c.metricsMu.Unlock()

return elem.Value.(*Item[K, V])
}

c.metricsMu.Lock()
c.metrics.Misses++
c.metricsMu.Unlock()
if c.loader != nil {
return c.loader.Load(c, key)
}

if c.loader == nil {
return nil
}

c.items.mu.Lock()
defer c.items.mu.Unlock()

// we don't want to extend expiration if the item is found
elem := c.items.values[key]
if elem == nil {
val, ttl, ok := c.loader.Load(key)
if !ok {
return nil
}

return c.set(key, val, ttl)
}
c.metricsMu.Lock()
c.metrics.Hits++
c.metricsMu.Unlock()

return elem.(*Item[K, V])
return elem.Value.(*Item[K, V])
}

// Delete deletes an item from the cache. If the item associated with
Expand Down Expand Up @@ -449,68 +433,59 @@ func (c *Cache[K, V]) Stop() {

// Loader is an interface that handles missing data loading.
type Loader[K comparable, V any] interface {
// Load should return a value and its TTL by the provided key.
// The bool return value should indicate whether the value
// and TTL are valid or not (true == valid).
Load(key K) (V, time.Duration, bool)
// Load should execute a custom item retrieval logic and
// return the item that is associated with the key.
// It should return nil if the item is not found/valid.
// The method is allowed to fetch data from the cache instance
// or update it for future use.
Load(c *Cache[K, V], key K) *Item[K, V]
}

// LoaderFunc type is an adapter that allows the use of ordinary
// functions as data loaders.
type LoaderFunc[K comparable, V any] func(K) (V, time.Duration, bool)
type LoaderFunc[K comparable, V any] func(*Cache[K, V], K) *Item[K, V]

// Load returns a value and its TTL by the provided key.
// The bool return value indicates whether the value
// and TTL are valid or not (true == valid).
func (l LoaderFunc[K, V]) Load(key K) (V, time.Duration, bool) {
return l(key)
// Load executes a custom item retrieval logic and returns the item that
// is associated with the key.
// It returns nil if the item is not found/valid.
func (l LoaderFunc[K, V]) Load(c *Cache[K, V], key K) *Item[K, V] {
return l(c, key)
}

// SyncLoader wraps another Loader and supresses duplicate
// SuppressedLoader wraps another Loader and suppresses duplicate
// calls to its Load method.
type SyncLoader[K comparable, V any] struct {
type SuppressedLoader[K comparable, V any] struct {
Loader[K, V]

group *singleflight.Group
l Loader[K, V]
}

// Load returns a value and its TTL by the provided key.
// The bool return value indicates whether the value
// and TTL are valid or not (true == valid).
// Load executes a custom item retrieval logic and returns the item that
// is associated with the key.
// It returns nil if the item is not found/valid.
// It also ensures that only one execution of the wrapped Loader's Load
// method is in-flight for a given key at a time.
func (l *SyncLoader[K, V]) Load(key K) (V, time.Duration, bool) {
func (l *SuppressedLoader[K, V]) Load(c *Cache[K, V], key K) *Item[K, V] {
// there should be a better/generic way to create a
// singleflight Group's key. It's possible that a generic
// singleflight.Group will be introduced with/in go1.19+
strKey := fmt.Sprint(key)

// the error can be discarded since the singleflight.Group
// itself does not return any of its errors, it returns
// the errors that we return ourselves in the func below (all of
// them are nil)
resInterf, _, _ := l.group.Do(strKey, func() (interface{}, error) {
value, ttl, ok := l.l.Load(key)
if !ok {
// the error that we return ourselves in the func below, which
// is also nil
res, _, _ := l.group.Do(strKey, func() (interface{}, error) {
item := l.Loader.Load(c, key)
if item == nil {
return nil, nil
}

return &syncLoaderResult{
value: value,
ttl: ttl,
}, nil
return item, nil
})
if res == nil {
var empty V
return empty, 0, false
return nil
}

res := resInterf.(*syncLoaderResult)

return res.value, res.ttl, true
}

// syncLoaderResult is the result type of SyncLoader.
type syncLoaderResult[V any] struct {
value V
ttl time.Duration
return res.(*Item[K, V])
}
159 changes: 157 additions & 2 deletions cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"golang.org/x/sync/singleflight"
)

func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}

func Test_New(t *testing.T) {
c := New[string, string]()
require.NotNil(t, c)
assert.NotNil(t, c.loaderGroup)
assert.NotNil(t, c.stopCh)
assert.NotNil(t, c.items.values)
assert.NotNil(t, c.items.lru)
Expand Down Expand Up @@ -463,7 +468,66 @@ func Test_Cache_Set(t *testing.T) {
}

func Test_Cache_Get(t *testing.T) {
//cache := prepCache(time.Hour, "test1", "test2", "test3")
const notFoundKey, foundKey = "notfound", "test1"
cc := map[string]struct {
Key string
Loader Loader[string, string]
Metrics Metrics
Result *Item[string, string]
}{
"Get without loader when item is not found": {
Key: notFoundKey,
Metrics: Metrics{
Misses: 1,
},
},
"Get with loader that returns non nil value when item is not found": {
Key: notFoundKey,
Loader: LoaderFunc[string, string](func(_ *Cache[string, string], _ string) *Item[string, string] {
return &Item[string, string]{key: "test"}
}),
Metrics: Metrics{
Misses: 1,
},
Result: &Item[string, string]{key: "test"},
},
"Get with loader that returns nil value when item is not found": {
Key: notFoundKey,
Loader: LoaderFunc[string, string](func(_ *Cache[string, string], _ string) *Item[string, string] {
return nil
}),
Metrics: Metrics{
Misses: 1,
},
},
"Get when item is found": {
Key: foundKey,
Metrics: Metrics{
Hits: 1,
},
},
}

for cn, c := range cc {
c := c

t.Run(cn, func(t *testing.T) {
t.Parallel()

cache := prepCache(time.Minute, foundKey, "test2", "test3")
cache.loader = c.Loader

res := cache.Get(c.Key)

if c.Key == foundKey {
c.Result = cache.items.values[foundKey].Value.(*Item[string, string])
assert.Equal(t, foundKey, cache.items.lru.Front().Value.(*Item[string, string]).key)
}

assert.Equal(t, c.Result, res)
assert.Equal(t, c.Metrics, cache.metrics)
})
}
}

func Test_Cache_Delete(t *testing.T) {
Expand Down Expand Up @@ -680,6 +744,97 @@ func Test_Cache_Stop(t *testing.T) {
assert.Len(t, cache.stopCh, 1)
}

func Test_LoaderFunc_Load(t *testing.T) {
var called bool

fn := LoaderFunc[string, string](func(_ *Cache[string, string], _ string) *Item[string, string] {
called = true
return nil
})

assert.Nil(t, fn(nil, ""))
assert.True(t, called)
}

func Test_SuppressedLoader_Load(t *testing.T) {
var (
mu sync.Mutex
loadCalls int
releaseCh = make(chan struct{})
res *Item[string, string]
)

l := SuppressedLoader[string, string]{
Loader: LoaderFunc[string, string](func(_ *Cache[string, string], _ string) *Item[string, string] {
mu.Lock()
loadCalls++
mu.Unlock()

<-releaseCh

if res == nil {
return nil
}

res1 := *res

return &res1
}),
group: &singleflight.Group{},
}

var (
wg sync.WaitGroup
item1, item2 *Item[string, string]
)

cache := prepCache(time.Hour)

// nil result
wg.Add(2)

go func() {
item1 = l.Load(cache, "test")
wg.Done()
}()

go func() {
item2 = l.Load(cache, "test")
wg.Done()
}()

time.Sleep(time.Millisecond * 100) // wait for goroutines to halt
releaseCh <- struct{}{}

wg.Wait()
require.Nil(t, item1)
require.Nil(t, item2)
assert.Equal(t, 1, loadCalls)

// non nil result
res = &Item[string, string]{key: "test"}
loadCalls = 0
wg.Add(2)

go func() {
item1 = l.Load(cache, "test")
wg.Done()
}()

go func() {
item2 = l.Load(cache, "test")
wg.Done()
}()

time.Sleep(time.Millisecond * 100) // wait for goroutines to halt
releaseCh <- struct{}{}

wg.Wait()
require.Same(t, item1, item2)
assert.Equal(t, "test", item1.key)
assert.Equal(t, 1, loadCalls)
}

func prepCache(ttl time.Duration, keys ...string) *Cache[string, string] {
c := &Cache[string, string]{defaultTTL: ttl}
c.items.values = make(map[string]*list.Element)
Expand Down

0 comments on commit a84d130

Please sign in to comment.