Skip to content

Commit

Permalink
feat: add ConcurrentMap
Browse files Browse the repository at this point in the history
  • Loading branch information
duke-git committed Jul 24, 2023
1 parent fe0264f commit fe0cb04
Show file tree
Hide file tree
Showing 2 changed files with 272 additions and 0 deletions.
154 changes: 154 additions & 0 deletions maputil/concurrentmap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Copyright 2021 [email protected]. All rights reserved.
// Use of this source code is governed by MIT license

// Package maputil includes some functions to manipulate map.
package maputil

import (
"fmt"
"sync"
)

const defaultShardCount = 32

// ConcurrentMap is like map, but is safe for concurrent use by multiple goroutines.
type ConcurrentMap[K comparable, V any] struct {
shardCount uint64
locks []sync.RWMutex
maps []map[K]V
}

// NewConcurrentMap create a ConcurrentMap with specific shard count.
func NewConcurrentMap[K comparable, V any](shardCount int) *ConcurrentMap[K, V] {
if shardCount <= 0 {
shardCount = defaultShardCount
}

cm := &ConcurrentMap[K, V]{
shardCount: uint64(shardCount),
locks: make([]sync.RWMutex, shardCount),
maps: make([]map[K]V, shardCount),
}

for i := range cm.maps {
cm.maps[i] = make(map[K]V)
}

return cm
}

// Set the value for a key.
// Play: todo
func (cm *ConcurrentMap[K, V]) Set(key K, value V) {
shard := cm.getShard(key)

cm.locks[shard].Lock()
cm.maps[shard][key] = value

cm.locks[shard].Unlock()
}

// Get the value stored in the map for a key, or nil if no.
// Play: todo
func (cm *ConcurrentMap[K, V]) Get(key K) (V, bool) {
shard := cm.getShard(key)

cm.locks[shard].RLock()
value, ok := cm.maps[shard][key]
cm.locks[shard].RUnlock()

return value, ok
}

// GetOrSet returns the existing value for the key if present.
// Otherwise, it sets and returns the given value.
// Play: todo
func (cm *ConcurrentMap[K, V]) GetOrSet(key K, value V) (actual V, ok bool) {
shard := cm.getShard(key)

cm.locks[shard].RLock()
if actual, ok := cm.maps[shard][key]; ok {
cm.locks[shard].RUnlock()
return actual, ok
}
cm.locks[shard].RUnlock()

// lock again
cm.locks[shard].Lock()
if actual, ok = cm.maps[shard][key]; ok {
cm.locks[shard].Unlock()
return
}

cm.maps[shard][key] = value
cm.locks[shard].Unlock()

return value, ok
}

// Delete the value for a key.
// Play: todo
func (cm *ConcurrentMap[K, V]) Delete(key K) {
shard := cm.getShard(key)

cm.locks[shard].Lock()
delete(cm.maps[shard], key)
cm.locks[shard].Unlock()
}

// GetAndDelete returns the existing value for the key if present and then delete the value for the key.
// Otherwise, do nothing, just return false
// Play: todo
func (cm *ConcurrentMap[K, V]) GetAndDelete(key K) (actual V, ok bool) {
shard := cm.getShard(key)

cm.locks[shard].RLock()
if actual, ok = cm.maps[shard][key]; ok {
cm.locks[shard].RUnlock()
cm.Delete(key)
return
}
cm.locks[shard].RUnlock()

return actual, false
}

// Has checks if map has the value for a key.
// Play: todo
func (cm *ConcurrentMap[K, V]) Has(key K) bool {
_, ok := cm.Get(key)
return ok
}

// Range calls iterator sequentially for each key and value present in each of the shards in the map.
// If iterator returns false, range stops the iteration.
func (cm *ConcurrentMap[K, V]) Range(iterator func(key K, value V) bool) {
for shard := range cm.locks {
cm.locks[shard].RLock()

for k, v := range cm.maps[shard] {
if !iterator(k, v) {
cm.locks[shard].RUnlock()
return
}
}
cm.locks[shard].RUnlock()
}
}

// getShard get shard by a key.
func (cm *ConcurrentMap[K, V]) getShard(key K) uint64 {
hash := fnv32(fmt.Sprintf("%v", key))
return uint64(hash) % cm.shardCount
}

func fnv32(key string) uint32 {
hash := uint32(2166136261)
const prime32 = uint32(16777619)
keyLength := len(key)
for i := 0; i < keyLength; i++ {
hash *= prime32
hash ^= uint32(key[i])
}
return hash
}
118 changes: 118 additions & 0 deletions maputil/concurrentmap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package maputil

import (
"fmt"
"sync"
"testing"

"github.com/duke-git/lancet/v2/internal"
)

func TestConcurrentMap_Set_Get(t *testing.T) {
assert := internal.NewAssert(t, "TestConcurrentMap_Set_Get")

cm := NewConcurrentMap[string, int](100)

var wg1 sync.WaitGroup
wg1.Add(10)

for i := 0; i < 10; i++ {
go func(n int) {
cm.Set(fmt.Sprintf("%d", n), n)
wg1.Done()
}(i)
}
wg1.Wait()

var wg2 sync.WaitGroup
wg2.Add(10)
for j := 0; j < 10; j++ {
go func(n int) {
val, ok := cm.Get(fmt.Sprintf("%d", n))
assert.Equal(n, val)
assert.Equal(true, ok)
wg2.Done()
}(j)
}
wg2.Wait()
}

func TestConcurrentMap_GetOrSet(t *testing.T) {
assert := internal.NewAssert(t, "TestConcurrentMap_GetOrSet")

cm := NewConcurrentMap[string, int](100)

for i := 0; i < 5; i++ {
go func(n int) {
val, ok := cm.GetOrSet(fmt.Sprintf("%d", n), n)
assert.Equal(n, val)
assert.Equal(false, ok)
}(i)
}

for j := 0; j < 5; j++ {
go func(n int) {
val, ok := cm.Get(fmt.Sprintf("%d", n))
assert.Equal(n, val)
assert.Equal(true, ok)
}(j)
}
}

func TestConcurrentMap_Delete(t *testing.T) {
assert := internal.NewAssert(t, "TestConcurrentMap_Delete")

cm := NewConcurrentMap[string, int](100)

var wg1 sync.WaitGroup
wg1.Add(10)

for i := 0; i < 10; i++ {
go func(n int) {
cm.Set(fmt.Sprintf("%d", n), n)
wg1.Done()
}(i)
}
wg1.Wait()

var wg2 sync.WaitGroup
wg2.Add(10)

for i := 0; i < 10; i++ {
go func(n int) {
cm.Delete(fmt.Sprintf("%d", n))
wg2.Done()
}(i)
}
wg2.Wait()

for j := 0; j < 10; j++ {
go func(n int) {
_, ok := cm.Get(fmt.Sprintf("%d", n))
assert.Equal(false, ok)
}(j)
}
}

func TestConcurrentMap_GetAndDelete(t *testing.T) {
assert := internal.NewAssert(t, "TestConcurrentMap_GetAndDelete")

cm := NewConcurrentMap[string, int](100)

for i := 0; i < 10; i++ {
go func(n int) {
cm.Set(fmt.Sprintf("%d", n), n)
}(i)
}

for j := 0; j < 10; j++ {
go func(n int) {
val, ok := cm.GetAndDelete(fmt.Sprintf("%d", n))
assert.Equal(n, val)
assert.Equal(true, ok)

_, ok = cm.Get(fmt.Sprintf("%d", n))
assert.Equal(false, ok)
}(j)
}
}

0 comments on commit fe0cb04

Please sign in to comment.