Skip to content

Commit

Permalink
Add functional options
Browse files Browse the repository at this point in the history
  • Loading branch information
swithek committed Feb 10, 2022
1 parent 3a7c2e1 commit ff13e7b
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 34 deletions.
31 changes: 18 additions & 13 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,12 @@ type Cache[K comparable, V any] struct {
}
}

stopCh chan struct{}

loader Loader[K, V]
capacity uint64
defaultTTL time.Duration
stopCh chan struct{}
options options[K, V]
}

// New creates a new instance of cache.
func New[K comparable, V any]() *Cache[K, V] {
func New[K comparable, V any](opts ...Option[K, V]) *Cache[K, V] {
c := &Cache[K, V]{
stopCh: make(chan struct{}),
}
Expand All @@ -71,6 +68,8 @@ func New[K comparable, V any]() *Cache[K, V] {
c.events.insertion.fns = make(map[uint64]func(*Item[K, V]))
c.events.eviction.fns = make(map[uint64]func(EvictionReason, *Item[K, V]))

applyOptions(&c.options, opts...)

return c
}

Expand Down Expand Up @@ -128,7 +127,7 @@ func (c *Cache[K, V]) updateExpirations(fresh bool, elem *list.Element) {
// Not concurrently safe.
func (c *Cache[K, V]) set(key K, value V, ttl time.Duration) *Item[K, V] {
if ttl == DefaultTTL {
ttl = c.defaultTTL
ttl = c.options.ttl
}

elem := c.get(key, false)
Expand All @@ -142,7 +141,7 @@ func (c *Cache[K, V]) set(key K, value V, ttl time.Duration) *Item[K, V] {
return item
}

if c.capacity != 0 && uint64(len(c.items.values)) >= c.capacity {
if c.options.capacity != 0 && uint64(len(c.items.values)) >= c.options.capacity {
// delete the oldest item
c.evict(EvictionReasonCapacityReached, c.items.lru.Back())
}
Expand Down Expand Up @@ -252,7 +251,13 @@ func (c *Cache[K, V]) Set(key K, value V, ttl time.Duration) *Item[K, V] {

// Get retrieves an item from the cache by the provided key.
// If the item is not found, a nil value is returned.
func (c *Cache[K, V]) Get(key K) *Item[K, V] {
func (c *Cache[K, V]) Get(key K, opts ...Option[K, V]) *Item[K, V] {
getOpts := options[K, V]{
loader: c.options.loader,
}

applyOptions(&getOpts, opts...)

c.items.mu.Lock()
elem := c.get(key, true)
c.items.mu.Unlock()
Expand All @@ -262,8 +267,8 @@ func (c *Cache[K, V]) Get(key K) *Item[K, V] {
c.metrics.Misses++
c.metricsMu.Unlock()

if c.loader != nil {
return c.loader.Load(c, key)
if getOpts.loader != nil {
return getOpts.loader.Load(c, key)
}

return nil
Expand Down Expand Up @@ -393,8 +398,8 @@ func (c *Cache[K, V]) Start() {
return d
}

if c.defaultTTL > 0 {
return c.defaultTTL
if c.options.ttl > 0 {
return c.options.ttl
}

return time.Hour
Expand Down
88 changes: 67 additions & 21 deletions cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,20 @@ func TestMain(m *testing.M) {
}

func Test_New(t *testing.T) {
c := New[string, string]()
c := New[string, string](
WithTTL[string, string](time.Hour),
WithCapacity[string, string](1),
)
require.NotNil(t, c)
assert.NotNil(t, c.stopCh)
assert.NotNil(t, c.items.values)
assert.NotNil(t, c.items.lru)
assert.NotNil(t, c.items.expQueue)
assert.NotNil(t, c.items.timerCh)
assert.NotNil(t, c.events.insertion.fns)
assert.NotNil(t, c.events.eviction.fns)
assert.Equal(t, time.Hour, c.options.ttl)
assert.Equal(t, uint64(1), c.options.capacity)
}

func Test_Cache_updateExpirations(t *testing.T) {
Expand Down Expand Up @@ -241,8 +248,8 @@ func Test_Cache_set(t *testing.T) {
)

cache := prepCache(time.Hour, evictedKey, existingKey, "test3")
cache.capacity = c.Capacity
cache.defaultTTL = time.Minute * 20
cache.options.capacity = c.Capacity
cache.options.ttl = time.Minute * 20
cache.events.insertion.fns[1] = func(item *Item[string, string]) {
assert.Equal(t, newKey, item.key)
insertFnsMu.Lock()
Expand Down Expand Up @@ -298,8 +305,8 @@ func Test_Cache_set(t *testing.T) {

switch {
case c.TTL == DefaultTTL:
assert.Equal(t, cache.defaultTTL, item.ttl)
assert.WithinDuration(t, time.Now(), item.expiresAt, cache.defaultTTL)
assert.Equal(t, cache.options.ttl, item.ttl)
assert.WithinDuration(t, time.Now(), item.expiresAt, cache.options.ttl)
assert.Equal(t, c.Key, cache.items.expQueue[0].Value.(*Item[string, string]).key)
case c.TTL > DefaultTTL:
assert.Equal(t, c.TTL, item.ttl)
Expand Down Expand Up @@ -470,32 +477,70 @@ func Test_Cache_Set(t *testing.T) {
func Test_Cache_Get(t *testing.T) {
const notFoundKey, foundKey = "notfound", "test1"
cc := map[string]struct {
Key string
Loader Loader[string, string]
Metrics Metrics
Result *Item[string, string]
Key string
DefaultOptions options[string, string]
CallOptions []Option[string, string]
Metrics Metrics
Result *Item[string, string]
}{
"Get without loader when item is not found": {
Key: notFoundKey,
Metrics: Metrics{
Misses: 1,
},
},
"Get with loader that returns non nil value when item is not found": {
"Get with default loader that returns non nil value when item is not found": {
Key: notFoundKey,
Loader: LoaderFunc[string, string](func(_ *Cache[string, string], _ string) *Item[string, string] {
return &Item[string, string]{key: "test"}
}),
DefaultOptions: options[string, string]{
loader: LoaderFunc[string, string](func(_ *Cache[string, string], _ string) *Item[string, string] {
return &Item[string, string]{key: "test"}
}),
},
Metrics: Metrics{
Misses: 1,
},
Result: &Item[string, string]{key: "test"},
},
"Get with loader that returns nil value when item is not found": {
"Get with default loader that returns nil value when item is not found": {
Key: notFoundKey,
Loader: LoaderFunc[string, string](func(_ *Cache[string, string], _ string) *Item[string, string] {
return nil
}),
DefaultOptions: options[string, string]{
loader: LoaderFunc[string, string](func(_ *Cache[string, string], _ string) *Item[string, string] {
return nil
}),
},
Metrics: Metrics{
Misses: 1,
},
},
"Get with call loader that returns non nil value when item is not found": {
Key: notFoundKey,
DefaultOptions: options[string, string]{
loader: LoaderFunc[string, string](func(_ *Cache[string, string], _ string) *Item[string, string] {
return &Item[string, string]{key: "test"}
}),
},
CallOptions: []Option[string, string]{
WithLoader[string, string](LoaderFunc[string, string](func(_ *Cache[string, string], _ string) *Item[string, string] {
return &Item[string, string]{key: "hello"}
})),
},
Metrics: Metrics{
Misses: 1,
},
Result: &Item[string, string]{key: "hello"},
},
"Get with call loader that returns nil value when item is not found": {
Key: notFoundKey,
DefaultOptions: options[string, string]{
loader: LoaderFunc[string, string](func(_ *Cache[string, string], _ string) *Item[string, string] {
return &Item[string, string]{key: "test"}
}),
},
CallOptions: []Option[string, string]{
WithLoader[string, string](LoaderFunc[string, string](func(_ *Cache[string, string], _ string) *Item[string, string] {
return nil
})),
},
Metrics: Metrics{
Misses: 1,
},
Expand All @@ -515,9 +560,9 @@ func Test_Cache_Get(t *testing.T) {
t.Parallel()

cache := prepCache(time.Minute, foundKey, "test2", "test3")
cache.loader = c.Loader
cache.options = c.DefaultOptions

res := cache.Get(c.Key)
res := cache.Get(c.Key, c.CallOptions...)

if c.Key == foundKey {
c.Result = cache.items.values[foundKey].Value.(*Item[string, string])
Expand Down Expand Up @@ -720,7 +765,7 @@ func Test_Cache_Start(t *testing.T) {
cache.items.mu.Lock()
addToCache(cache, time.Nanosecond, "2")
cache.items.mu.Unlock()
cache.defaultTTL = time.Hour
cache.options.ttl = time.Hour
cache.items.timerCh <- time.Millisecond
case 2:
cache.items.mu.Lock()
Expand Down Expand Up @@ -872,7 +917,8 @@ func Test_SuppressedLoader_Load(t *testing.T) {
}

func prepCache(ttl time.Duration, keys ...string) *Cache[string, string] {
c := &Cache[string, string]{defaultTTL: ttl}
c := &Cache[string, string]{}
c.options.ttl = ttl
c.items.values = make(map[string]*list.Element)
c.items.lru = list.New()
c.items.expQueue = newExpirationQueue[string, string]()
Expand Down
55 changes: 55 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package ttlcache

import "time"

// Option sets a specific cache option.
type Option[K comparable, V any] interface {
apply(opts *options[K, V])
}

// optionFunc wraps a function and implements the Option interface.
type optionFunc[K comparable, V any] func(*options[K, V])

// apply calls the wrapped function.
func (fn optionFunc[K, V]) apply(opts *options[K, V]) {
fn(opts)
}

// options holds all available cache configuration options.
type options[K comparable, V any] struct {
capacity uint64
ttl time.Duration
loader Loader[K, V]
}

// applyOptions applies the provided option values to the option struct.
func applyOptions[K comparable, V any](v *options[K, V], opts ...Option[K, V]) {
for i := range opts {
opts[i].apply(v)
}
}

// WithCapacity sets the maximum capacity of the cache.
// It has no effect when passing into Get().
func WithCapacity[K comparable, V any](c uint64) Option[K, V] {
return optionFunc[K, V](func(opts *options[K, V]) {
opts.capacity = c
})
}

// WithTTL sets the TTL of the cache.
// It has no effect when passing into Get().
func WithTTL[K comparable, V any](ttl time.Duration) Option[K, V] {
return optionFunc[K, V](func(opts *options[K, V]) {
opts.ttl = ttl
})
}

// WithLoader sets the loader of the cache.
// When passing into Get(), it sets an epheral loader that
// is used instead of the cache's default one.
func WithLoader[K comparable, V any](l Loader[K, V]) Option[K, V] {
return optionFunc[K, V](func(opts *options[K, V]) {
opts.loader = l
})
}
53 changes: 53 additions & 0 deletions options_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package ttlcache

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func Test_optionFunc_apply(t *testing.T) {
var called bool

optionFunc[string, string](func(_ *options[string, string]) {
called = true
}).apply(nil)
assert.True(t, called)
}

func Test_applyOptions(t *testing.T) {
var opts options[string, string]

applyOptions(&opts,
WithCapacity[string, string](12),
WithTTL[string, string](time.Hour),
)

assert.Equal(t, uint64(12), opts.capacity)
assert.Equal(t, time.Hour, opts.ttl)
}

func Test_WithCapacity(t *testing.T) {
var opts options[string, string]

WithCapacity[string, string](12).apply(&opts)
assert.Equal(t, uint64(12), opts.capacity)
}

func Test_WithTTL(t *testing.T) {
var opts options[string, string]

WithTTL[string, string](time.Hour).apply(&opts)
assert.Equal(t, time.Hour, opts.ttl)
}

func Test_WithLoader(t *testing.T) {
var opts options[string, string]

l := LoaderFunc[string, string](func(_ *Cache[string, string], _ string) *Item[string, string] {
return nil
})
WithLoader[string, string](l).apply(&opts)
assert.NotNil(t, opts.loader)
}

0 comments on commit ff13e7b

Please sign in to comment.