From fe0cb04137a7d70a270ca13bd6ffa6031195131b Mon Sep 17 00:00:00 2001 From: dudaodong Date: Mon, 24 Jul 2023 17:10:45 +0800 Subject: [PATCH] feat: add ConcurrentMap --- maputil/concurrentmap.go | 154 ++++++++++++++++++++++++++++++++++ maputil/concurrentmap_test.go | 118 ++++++++++++++++++++++++++ 2 files changed, 272 insertions(+) create mode 100644 maputil/concurrentmap.go create mode 100644 maputil/concurrentmap_test.go diff --git a/maputil/concurrentmap.go b/maputil/concurrentmap.go new file mode 100644 index 00000000..633370d9 --- /dev/null +++ b/maputil/concurrentmap.go @@ -0,0 +1,154 @@ +// Copyright 2021 dudaodong@gmail.com. All rights reserved. +// Use of this source code is governed by MIT license + +// Package maputil includes some functions to manipulate map. +package maputil + +import ( + "fmt" + "sync" +) + +const defaultShardCount = 32 + +// ConcurrentMap is like map, but is safe for concurrent use by multiple goroutines. +type ConcurrentMap[K comparable, V any] struct { + shardCount uint64 + locks []sync.RWMutex + maps []map[K]V +} + +// NewConcurrentMap create a ConcurrentMap with specific shard count. +func NewConcurrentMap[K comparable, V any](shardCount int) *ConcurrentMap[K, V] { + if shardCount <= 0 { + shardCount = defaultShardCount + } + + cm := &ConcurrentMap[K, V]{ + shardCount: uint64(shardCount), + locks: make([]sync.RWMutex, shardCount), + maps: make([]map[K]V, shardCount), + } + + for i := range cm.maps { + cm.maps[i] = make(map[K]V) + } + + return cm +} + +// Set the value for a key. +// Play: todo +func (cm *ConcurrentMap[K, V]) Set(key K, value V) { + shard := cm.getShard(key) + + cm.locks[shard].Lock() + cm.maps[shard][key] = value + + cm.locks[shard].Unlock() +} + +// Get the value stored in the map for a key, or nil if no. +// Play: todo +func (cm *ConcurrentMap[K, V]) Get(key K) (V, bool) { + shard := cm.getShard(key) + + cm.locks[shard].RLock() + value, ok := cm.maps[shard][key] + cm.locks[shard].RUnlock() + + return value, ok +} + +// GetOrSet returns the existing value for the key if present. +// Otherwise, it sets and returns the given value. +// Play: todo +func (cm *ConcurrentMap[K, V]) GetOrSet(key K, value V) (actual V, ok bool) { + shard := cm.getShard(key) + + cm.locks[shard].RLock() + if actual, ok := cm.maps[shard][key]; ok { + cm.locks[shard].RUnlock() + return actual, ok + } + cm.locks[shard].RUnlock() + + // lock again + cm.locks[shard].Lock() + if actual, ok = cm.maps[shard][key]; ok { + cm.locks[shard].Unlock() + return + } + + cm.maps[shard][key] = value + cm.locks[shard].Unlock() + + return value, ok +} + +// Delete the value for a key. +// Play: todo +func (cm *ConcurrentMap[K, V]) Delete(key K) { + shard := cm.getShard(key) + + cm.locks[shard].Lock() + delete(cm.maps[shard], key) + cm.locks[shard].Unlock() +} + +// GetAndDelete returns the existing value for the key if present and then delete the value for the key. +// Otherwise, do nothing, just return false +// Play: todo +func (cm *ConcurrentMap[K, V]) GetAndDelete(key K) (actual V, ok bool) { + shard := cm.getShard(key) + + cm.locks[shard].RLock() + if actual, ok = cm.maps[shard][key]; ok { + cm.locks[shard].RUnlock() + cm.Delete(key) + return + } + cm.locks[shard].RUnlock() + + return actual, false +} + +// Has checks if map has the value for a key. +// Play: todo +func (cm *ConcurrentMap[K, V]) Has(key K) bool { + _, ok := cm.Get(key) + return ok +} + +// Range calls iterator sequentially for each key and value present in each of the shards in the map. +// If iterator returns false, range stops the iteration. +func (cm *ConcurrentMap[K, V]) Range(iterator func(key K, value V) bool) { + for shard := range cm.locks { + cm.locks[shard].RLock() + + for k, v := range cm.maps[shard] { + if !iterator(k, v) { + cm.locks[shard].RUnlock() + return + } + } + cm.locks[shard].RUnlock() + } +} + +// getShard get shard by a key. +func (cm *ConcurrentMap[K, V]) getShard(key K) uint64 { + hash := fnv32(fmt.Sprintf("%v", key)) + return uint64(hash) % cm.shardCount +} + +func fnv32(key string) uint32 { + hash := uint32(2166136261) + const prime32 = uint32(16777619) + keyLength := len(key) + for i := 0; i < keyLength; i++ { + hash *= prime32 + hash ^= uint32(key[i]) + } + return hash +} diff --git a/maputil/concurrentmap_test.go b/maputil/concurrentmap_test.go new file mode 100644 index 00000000..4557550e --- /dev/null +++ b/maputil/concurrentmap_test.go @@ -0,0 +1,118 @@ +package maputil + +import ( + "fmt" + "sync" + "testing" + + "github.com/duke-git/lancet/v2/internal" +) + +func TestConcurrentMap_Set_Get(t *testing.T) { + assert := internal.NewAssert(t, "TestConcurrentMap_Set_Get") + + cm := NewConcurrentMap[string, int](100) + + var wg1 sync.WaitGroup + wg1.Add(10) + + for i := 0; i < 10; i++ { + go func(n int) { + cm.Set(fmt.Sprintf("%d", n), n) + wg1.Done() + }(i) + } + wg1.Wait() + + var wg2 sync.WaitGroup + wg2.Add(10) + for j := 0; j < 10; j++ { + go func(n int) { + val, ok := cm.Get(fmt.Sprintf("%d", n)) + assert.Equal(n, val) + assert.Equal(true, ok) + wg2.Done() + }(j) + } + wg2.Wait() +} + +func TestConcurrentMap_GetOrSet(t *testing.T) { + assert := internal.NewAssert(t, "TestConcurrentMap_GetOrSet") + + cm := NewConcurrentMap[string, int](100) + + for i := 0; i < 5; i++ { + go func(n int) { + val, ok := cm.GetOrSet(fmt.Sprintf("%d", n), n) + assert.Equal(n, val) + assert.Equal(false, ok) + }(i) + } + + for j := 0; j < 5; j++ { + go func(n int) { + val, ok := cm.Get(fmt.Sprintf("%d", n)) + assert.Equal(n, val) + assert.Equal(true, ok) + }(j) + } +} + +func TestConcurrentMap_Delete(t *testing.T) { + assert := internal.NewAssert(t, "TestConcurrentMap_Delete") + + cm := NewConcurrentMap[string, int](100) + + var wg1 sync.WaitGroup + wg1.Add(10) + + for i := 0; i < 10; i++ { + go func(n int) { + cm.Set(fmt.Sprintf("%d", n), n) + wg1.Done() + }(i) + } + wg1.Wait() + + var wg2 sync.WaitGroup + wg2.Add(10) + + for i := 0; i < 10; i++ { + go func(n int) { + cm.Delete(fmt.Sprintf("%d", n)) + wg2.Done() + }(i) + } + wg2.Wait() + + for j := 0; j < 10; j++ { + go func(n int) { + _, ok := cm.Get(fmt.Sprintf("%d", n)) + assert.Equal(false, ok) + }(j) + } +} + +func TestConcurrentMap_GetAndDelete(t *testing.T) { + assert := internal.NewAssert(t, "TestConcurrentMap_GetAndDelete") + + cm := NewConcurrentMap[string, int](100) + + for i := 0; i < 10; i++ { + go func(n int) { + cm.Set(fmt.Sprintf("%d", n), n) + }(i) + } + + for j := 0; j < 10; j++ { + go func(n int) { + val, ok := cm.GetAndDelete(fmt.Sprintf("%d", n)) + assert.Equal(n, val) + assert.Equal(true, ok) + + _, ok = cm.Get(fmt.Sprintf("%d", n)) + assert.Equal(false, ok) + }(j) + } +}