@@ -9,20 +9,25 @@ import (
9
9
"reflect"
10
10
"time"
11
11
12
+ "github.com/patrickmn/go-cache"
12
13
"github.com/revel/revel"
13
- "github.com/robfig/go-cache "
14
+ "sync "
14
15
)
15
16
16
17
type InMemoryCache struct {
17
- cache.Cache
18
+ cache cache.Cache // Only expose the methods we want to make available
19
+ mu sync.RWMutex // For increment / decrement prevent reads and writes
18
20
}
19
21
20
22
func NewInMemoryCache (defaultExpiration time.Duration ) InMemoryCache {
21
- return InMemoryCache {* cache .New (defaultExpiration , time .Minute )}
23
+ return InMemoryCache {cache : * cache .New (defaultExpiration , time .Minute ), mu : sync. RWMutex {} }
22
24
}
23
25
24
26
func (c InMemoryCache ) Get (key string , ptrValue interface {}) error {
25
- value , found := c .Cache .Get (key )
27
+ c .mu .RLock ()
28
+ defer c .mu .RUnlock ()
29
+
30
+ value , found := c .cache .Get (key )
26
31
if ! found {
27
32
return ErrCacheMiss
28
33
}
@@ -43,50 +48,118 @@ func (c InMemoryCache) GetMulti(keys ...string) (Getter, error) {
43
48
}
44
49
45
50
func (c InMemoryCache ) Set (key string , value interface {}, expires time.Duration ) error {
51
+ c .mu .Lock ()
52
+ defer c .mu .Unlock ()
46
53
// NOTE: go-cache understands the values of DefaultExpiryTime and ForEverNeverExpiry
47
- c .Cache .Set (key , value , expires )
54
+ c .cache .Set (key , value , expires )
48
55
return nil
49
56
}
50
57
51
58
func (c InMemoryCache ) Add (key string , value interface {}, expires time.Duration ) error {
52
- err := c .Cache .Add (key , value , expires )
53
- if err == cache .ErrKeyExists {
59
+ c .mu .Lock ()
60
+ defer c .mu .Unlock ()
61
+ err := c .cache .Add (key , value , expires )
62
+ if err != nil {
54
63
return ErrNotStored
55
64
}
56
65
return err
57
66
}
58
67
59
68
func (c InMemoryCache ) Replace (key string , value interface {}, expires time.Duration ) error {
60
- if err := c .Cache .Replace (key , value , expires ); err != nil {
69
+ c .mu .Lock ()
70
+ defer c .mu .Unlock ()
71
+ if err := c .cache .Replace (key , value , expires ); err != nil {
61
72
return ErrNotStored
62
73
}
63
74
return nil
64
75
}
65
76
66
77
func (c InMemoryCache ) Delete (key string ) error {
67
- if found := c .Cache .Delete (key ); ! found {
78
+ c .mu .RLock ()
79
+ defer c .mu .RUnlock ()
80
+ if _ , found := c .cache .Get (key ); ! found {
68
81
return ErrCacheMiss
69
82
}
83
+ c .cache .Delete (key )
70
84
return nil
71
85
}
72
86
73
87
func (c InMemoryCache ) Increment (key string , n uint64 ) (newValue uint64 , err error ) {
74
- newValue , err = c .Cache .Increment (key , n )
75
- if err == cache .ErrCacheMiss {
88
+ c .mu .Lock ()
89
+ defer c .mu .Unlock ()
90
+ if _ , found := c .cache .Get (key ); ! found {
76
91
return 0 , ErrCacheMiss
77
92
}
78
- return
93
+ if err = c .cache .Increment (key , int64 (n )); err != nil {
94
+ return
95
+ }
96
+
97
+ return c .convertTypeToUint64 (key )
79
98
}
80
99
81
100
func (c InMemoryCache ) Decrement (key string , n uint64 ) (newValue uint64 , err error ) {
82
- newValue , err = c .Cache .Decrement (key , n )
83
- if err == cache .ErrCacheMiss {
84
- return 0 , ErrCacheMiss
101
+ c .mu .Lock ()
102
+ defer c .mu .Unlock ()
103
+ if nv ,err := c .convertTypeToUint64 (key );err != nil {
104
+ return 0 , err
105
+ } else {
106
+ // Stop from going below zero
107
+ if n > nv {
108
+ n = nv
109
+ }
85
110
}
86
- return
111
+ if err = c .cache .Decrement (key , int64 (n )); err != nil {
112
+ return
113
+ }
114
+
115
+ return c .convertTypeToUint64 (key )
87
116
}
88
117
89
118
func (c InMemoryCache ) Flush () error {
90
- c .Cache .Flush ()
119
+ c .mu .Lock ()
120
+ defer c .mu .Unlock ()
121
+
122
+ c .cache .Flush ()
91
123
return nil
92
124
}
125
+
126
+ // Fetches and returns the converted type to a uint64
127
+ func (c InMemoryCache ) convertTypeToUint64 (key string ) (newValue uint64 , err error ) {
128
+ v , found := c .cache .Get (key )
129
+ if ! found {
130
+ return newValue , ErrCacheMiss
131
+ }
132
+
133
+ switch v .(type ) {
134
+ case int :
135
+ newValue = uint64 (v .(int ))
136
+ case int8 :
137
+ newValue = uint64 (v .(int8 ))
138
+ case int16 :
139
+ newValue = uint64 (v .(int16 ))
140
+ case int32 :
141
+ newValue = uint64 (v .(int32 ))
142
+ case int64 :
143
+ newValue = uint64 (v .(int64 ))
144
+ case uint :
145
+ newValue = uint64 (v .(uint ))
146
+ case uintptr :
147
+ newValue = uint64 (v .(uintptr ))
148
+ case uint8 :
149
+ newValue = uint64 (v .(uint8 ))
150
+ case uint16 :
151
+ newValue = uint64 (v .(uint16 ))
152
+ case uint32 :
153
+ newValue = uint64 (v .(uint32 ))
154
+ case uint64 :
155
+ newValue = uint64 (v .(uint64 ))
156
+ case float32 :
157
+ newValue = uint64 (v .(float32 ))
158
+ case float64 :
159
+ newValue = uint64 (v .(float64 ))
160
+ default :
161
+ err = ErrInvalidValue
162
+ }
163
+ return
164
+ }
165
+
0 commit comments