diff --git a/arc.go b/arc.go index e2015e9..21dd68d 100644 --- a/arc.go +++ b/arc.go @@ -2,6 +2,7 @@ package gcache import ( "container/list" + "errors" "time" ) @@ -358,6 +359,11 @@ func (c *ARC) Len(checkExpired bool) int { return length } +// Increment an item +func (c *ARC) Increment(k interface{}, n int64) (interface{}, error) { + return nil, errors.New("method not implemented") +} + // Purge is used to completely clear the cache func (c *ARC) Purge() { c.mu.Lock() diff --git a/cache.go b/cache.go index e13e6f1..8d680b3 100644 --- a/cache.go +++ b/cache.go @@ -28,6 +28,7 @@ type Cache interface { Keys(checkExpired bool) []interface{} Len(checkExpired bool) int Has(key interface{}) bool + Increment(k interface{}, n int64) (interface{}, error) statsAccessor } diff --git a/examples/custom_expiration.go b/examples/custom_expiration.go index 54f12a6..571d121 100644 --- a/examples/custom_expiration.go +++ b/examples/custom_expiration.go @@ -6,7 +6,7 @@ import ( "time" ) -func main() { +func main2() { gc := gcache.New(10). LFU(). Build() diff --git a/examples/example.go b/examples/example.go index 97c1b32..fdc6951 100644 --- a/examples/example.go +++ b/examples/example.go @@ -5,7 +5,7 @@ import ( "github.com/bluele/gcache" ) -func main() { +func main3() { gc := gcache.New(10). LFU(). Build() diff --git a/incrementer.go b/incrementer.go new file mode 100644 index 0000000..020757a --- /dev/null +++ b/incrementer.go @@ -0,0 +1,92 @@ +package gcache + +import ( + "fmt" + "sync" +) + +type BaseIncrementer interface { + Increment(k interface{}, n int64) (interface{}, error) +} + +func incrementValue(v interface{}, n int64) (result interface{}, err error) { + switch v.(type) { + case int: + v = v.(int) + int(n) + case int8: + v = v.(int8) + int8(n) + case int16: + v = v.(int16) + int16(n) + case int32: + v = v.(int32) + int32(n) + case int64: + v = v.(int64) + n + case uint: + v = v.(uint) + uint(n) + case uintptr: + v = v.(uintptr) + uintptr(n) + case uint8: + v = v.(uint8) + uint8(n) + case uint16: + v = v.(uint16) + uint16(n) + case uint32: + v = v.(uint32) + uint32(n) + case uint64: + v = v.(uint64) + uint64(n) + case float32: + v = v.(float32) + float32(n) + case float64: + v = v.(float64) + float64(n) + default: + return nil, fmt.Errorf("the value %v is not an integer", v) + } + // restore original interface + return v.(interface{}), nil +} + +// make sure that LRUIncrementer implements BaseIncrementer +var _ BaseIncrementer = &LRUIncrementer{} + +type LRUIncrementer struct { + cache *LRUCache + lock sync.RWMutex +} + +func newLRUIncrementer(c *LRUCache) *LRUIncrementer { + i := &LRUIncrementer{cache: c} + return i +} + +func (i *LRUIncrementer) Increment(key interface{}, n int64) (interface{}, error) { + i.cache.mu.Lock() + + item, found := i.cache.items[key] + if !found { + i.cache.mu.Unlock() + return nil, KeyNotFoundError + } + + it := item.Value.(*lruItem) + if it.IsExpired(nil) { + i.cache.removeElement(item) + i.cache.mu.Unlock() + return nil, KeyNotFoundError + } + + v := it.value + + vNew, err := incrementValue(v, n) + if err != nil { + i.cache.mu.Unlock() + return nil, err + } + + _, err = i.cache.set(key, vNew) + if err != nil { + i.cache.mu.Unlock() + return nil, err + } + i.cache.mu.Unlock() + + return vNew, nil +} diff --git a/lfu.go b/lfu.go index f781a1f..5df7287 100644 --- a/lfu.go +++ b/lfu.go @@ -2,6 +2,7 @@ package gcache import ( "container/list" + "errors" "time" ) @@ -311,6 +312,11 @@ func (c *LFUCache) Len(checkExpired bool) int { return length } +// Increment an item +func (c *LFUCache) Increment(k interface{}, n int64) (interface{}, error) { + return nil, errors.New("method not implemented") +} + // Completely clear the cache func (c *LFUCache) Purge() { c.mu.Lock() diff --git a/lru.go b/lru.go index a85d660..36219ab 100644 --- a/lru.go +++ b/lru.go @@ -8,8 +8,9 @@ import ( // Discards the least recently used items first. type LRUCache struct { baseCache - items map[interface{}]*list.Element - evictList *list.List + items map[interface{}]*list.Element + evictList *list.List + incrementer *LRUIncrementer } func newLRUCache(cb *CacheBuilder) *LRUCache { @@ -18,6 +19,7 @@ func newLRUCache(cb *CacheBuilder) *LRUCache { c.init() c.loadGroup.cache = c + c.incrementer = newLRUIncrementer(c) return c } @@ -281,6 +283,11 @@ func (c *LRUCache) Len(checkExpired bool) int { return length } +// Increment an item +func (c *LRUCache) Increment(k interface{}, n int64) (interface{}, error) { + return c.incrementer.Increment(k, n) +} + // Completely clear the cache func (c *LRUCache) Purge() { c.mu.Lock() diff --git a/lru_test.go b/lru_test.go index fa9140f..18a4e00 100644 --- a/lru_test.go +++ b/lru_test.go @@ -21,8 +21,8 @@ func TestLoadingLRUGet(t *testing.T) { func TestLRULength(t *testing.T) { gc := buildTestLoadingCache(t, TYPE_LRU, 1000, loader) - gc.Get("test1") - gc.Get("test2") + _, _ = gc.Get("test1") + _, _ = gc.Get("test2") length := gc.Len(true) expectedLength := 2 if length != expectedLength { @@ -52,8 +52,8 @@ func TestLRUHas(t *testing.T) { for i := 0; i < 10; i++ { t.Run(fmt.Sprint(i), func(t *testing.T) { - gc.Get("test1") - gc.Get("test2") + _, _ = gc.Get("test1") + _, _ = gc.Get("test2") if gc.Has("test0") { t.Fatal("should not have test0") @@ -79,3 +79,179 @@ func TestLRUHas(t *testing.T) { }) } } + +func TestBasicLRUIncrementer(t *testing.T) { + gc := buildTestLoadingCacheWithExpiration(t, TYPE_LRU, 100, 10*time.Second) + defer gc.Purge() + + // integer + err := gc.Set("some-key", 1) + if err != nil { + t.Error(err) + t.Fatal() + } + v, err := gc.Increment("some-key", 1) + if err != nil { + t.Error(err) + t.Fatal() + } + if v == nil { + t.Error(fmt.Errorf("v is nil")) + t.Fatal() + } + vNew, ok := v.(int) + if !ok { + t.Error(fmt.Errorf("vNew is not int")) + t.Fatal() + } + if vNew != 2 { + t.Error("increment int failed") + t.Fatal() + } + vFromC, err := gc.Get("some-key") + if err != nil { + t.Error(err) + t.Fatal() + } + if vFromC != vNew { + t.Error(fmt.Errorf("increment in cache by int64 failed, v:%v, vNew:%v", vNew, vFromC)) + } +} + +func TestLRUIncrementer(t *testing.T) { + gc := buildTestLoadingCacheWithExpiration(t, TYPE_LRU, 100, 10*time.Second) + defer gc.Purge() + featureValues := []interface{}{ + []int{1, 2}, + []int8{1, 0, 5}, + []int16{1, 0, 5}, + []int32{1, 0, 5}, + []int64{1, 0, 5}, + []uint{1, 0, 5}, + []uintptr{1, 0, 5}, + []uint8{1, 0, 5}, + []uint16{1, 0, 5}, + []uint32{1, 0, 5}, + []uint64{1, 0, 5}, + []float32{1.5, 0.2, 5.4}, + []float64{1.5, 0.7, 5.9}, + } + + for _, values := range featureValues { + forEachValue(values, func(i int, v interface{}) { + incrementBy := int64(1) + var incrementResult interface{} + + switch v.(type) { + case int: + v = v.(int) + incrementResult = int(incrementBy) + v.(int) + case int8: + v = v.(int8) + incrementResult = int8(incrementBy) + v.(int8) + case int16: + v = v.(int16) + incrementResult = int16(incrementBy) + v.(int16) + case int32: + v = v.(int32) + incrementResult = int32(incrementBy) + v.(int32) + case int64: + v = v.(int64) + incrementResult = int64(incrementBy) + v.(int64) + case uint: + v = v.(uint) + incrementResult = uint(incrementBy) + v.(uint) + case uintptr: + v = v.(uintptr) + incrementResult = uintptr(incrementBy) + v.(uintptr) + case uint8: + v = v.(uint8) + incrementResult = uint8(incrementBy) + v.(uint8) + case uint16: + v = v.(uint16) + incrementResult = uint16(incrementBy) + v.(uint16) + case uint32: + v = v.(uint32) + incrementResult = uint32(incrementBy) + v.(uint32) + case uint64: + v = v.(uint64) + incrementResult = uint64(incrementBy) + v.(uint64) + case float32: + v = v.(float32) + incrementResult = float32(incrementBy) + v.(float32) + case float64: + v = v.(float64) + incrementResult = float64(incrementBy) + v.(float64) + default: + t.Error(fmt.Errorf("the value %v is not an integer", v)) + t.Fatal() + } + + err := gc.Set("some-key", v) + if err != nil { + t.Error(err) + t.Fatal() + } + + vNew, err := gc.Increment("some-key", incrementBy) + if err != nil { + t.Error(err) + t.Fatal() + } + + if vNew == nil { + t.Error(fmt.Errorf("v is nil")) + t.Fatal() + } + + switch vNew.(type) { + case int: + vNew = vNew.(int) + case int8: + vNew = vNew.(int8) + case int16: + vNew = vNew.(int16) + case int32: + vNew = vNew.(int32) + case int64: + vNew = vNew.(int64) + case uint: + vNew = vNew.(uint) + case uintptr: + vNew = vNew.(uintptr) + case uint8: + vNew = vNew.(uint8) + case uint16: + vNew = vNew.(uint16) + case uint32: + vNew = vNew.(uint32) + case uint64: + vNew = vNew.(uint64) + case float32: + vNew = vNew.(float32) + case float64: + vNew = vNew.(float64) + default: + t.Error(fmt.Errorf("the value %v is not an integer", vNew)) + t.Fatal() + } + + if vNew != incrementResult { + t.Error(fmt.Errorf("increment result by int64 failed, v:%v, vNew:%v", v, vNew)) + //t.Fatal() + } else { + t.Logf("value:%v, incremented:%v, by number: %d \n", v, vNew, incrementBy) + } + + vFromC, err := gc.Get("some-key") + if err != nil { + t.Error(err) + t.Fatal() + } + // ugly hack to compare different types + if vFromC != vNew { + t.Error(fmt.Errorf("increment in cache by int64 failed, v:%v, vNew:%v", vNew, vFromC)) + } + }) + } +} diff --git a/simple.go b/simple.go index 7310af1..15f1755 100644 --- a/simple.go +++ b/simple.go @@ -1,6 +1,7 @@ package gcache import ( + "errors" "time" ) @@ -274,6 +275,11 @@ func (c *SimpleCache) Len(checkExpired bool) int { return length } +// Increment an item +func (c *SimpleCache) Increment(k interface{}, n int64) (interface{}, error) { + return nil, errors.New("method not implemented") +} + // Completely clear the cache func (c *SimpleCache) Purge() { c.mu.Lock() diff --git a/utils.go b/utils.go index 1f784e4..6d7d3fe 100644 --- a/utils.go +++ b/utils.go @@ -1,5 +1,10 @@ package gcache +import ( + "fmt" + "reflect" +) + func minInt(x, y int) int { if x < y { return x @@ -13,3 +18,17 @@ func maxInt(x, y int) int { } return y } +func forEachValue(ifaceSlice interface{}, f func(i int, val interface{})) { + v := reflect.ValueOf(ifaceSlice) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() != reflect.Slice { + panic(fmt.Errorf("forEachValue: expected slice type, found %q", v.Kind().String())) + } + + for i := 0; i < v.Len(); i++ { + val := v.Index(i).Interface() + f(i, val) + } +}