1
1
package doubleratchet
2
2
3
+ import (
4
+ "bytes"
5
+ "sort"
6
+ )
7
+
3
8
// KeysStorage is an interface of an abstract in-memory or persistent keys storage.
4
9
type KeysStorage interface {
5
10
// Get returns a message key by the given key and message number.
6
11
Get (k Key , msgNum uint ) (mk Key , ok bool , err error )
7
12
8
13
// Put saves the given mk under the specified key and msgNum.
9
- Put (k Key , msgNum uint , mk Key ) error
14
+ Put (sessionID [] byte , k Key , msgNum uint , mk Key , keySeqNum uint ) error
10
15
11
16
// DeleteMk ensures there's no message key under the specified key and msgNum.
12
17
DeleteMk (k Key , msgNum uint ) error
13
18
14
- // DeletePk ensures there's no message keys under the specified key.
15
- DeletePk (k Key ) error
19
+ // DeleteOldMKeys deletes old message keys for a session.
20
+ DeleteOldMks (sessionID []byte , deleteUntilSeqKey uint ) error
21
+
22
+ // TruncateMks truncates the number of keys to maxKeys.
23
+ TruncateMks (sessionID []byte , maxKeys int ) error
16
24
17
25
// Count returns number of message keys stored under the specified key.
18
26
Count (k Key ) (uint , error )
@@ -23,10 +31,10 @@ type KeysStorage interface {
23
31
24
32
// KeysStorageInMemory is an in-memory message keys storage.
25
33
type KeysStorageInMemory struct {
26
- keys map [Key ]map [uint ]Key
34
+ keys map [Key ]map [uint ]InMemoryKey
27
35
}
28
36
29
- // See KeysStorage .
37
+ // Get returns a message key by the given key and message number .
30
38
func (s * KeysStorageInMemory ) Get (pubKey Key , msgNum uint ) (Key , bool , error ) {
31
39
if s .keys == nil {
32
40
return Key {}, false , nil
@@ -39,22 +47,32 @@ func (s *KeysStorageInMemory) Get(pubKey Key, msgNum uint) (Key, bool, error) {
39
47
if ! ok {
40
48
return Key {}, false , nil
41
49
}
42
- return mk , true , nil
50
+ return mk .messageKey , true , nil
51
+ }
52
+
53
+ type InMemoryKey struct {
54
+ messageKey Key
55
+ seqNum uint
56
+ sessionID []byte
43
57
}
44
58
45
- // See KeysStorage .
46
- func (s * KeysStorageInMemory ) Put (pubKey Key , msgNum uint , mk Key ) error {
59
+ // Put saves the given mk under the specified key and msgNum .
60
+ func (s * KeysStorageInMemory ) Put (sessionID [] byte , pubKey Key , msgNum uint , mk Key , seqNum uint ) error {
47
61
if s .keys == nil {
48
- s .keys = make (map [Key ]map [uint ]Key )
62
+ s .keys = make (map [Key ]map [uint ]InMemoryKey )
49
63
}
50
64
if _ , ok := s .keys [pubKey ]; ! ok {
51
- s .keys [pubKey ] = make (map [uint ]Key )
65
+ s .keys [pubKey ] = make (map [uint ]InMemoryKey )
66
+ }
67
+ s.keys [pubKey ][msgNum ] = InMemoryKey {
68
+ sessionID : sessionID ,
69
+ messageKey : mk ,
70
+ seqNum : seqNum ,
52
71
}
53
- s.keys [pubKey ][msgNum ] = mk
54
72
return nil
55
73
}
56
74
57
- // See KeysStorage .
75
+ // DeleteMk ensures there's no message key under the specified key and msgNum .
58
76
func (s * KeysStorageInMemory ) DeleteMk (pubKey Key , msgNum uint ) error {
59
77
if s .keys == nil {
60
78
return nil
@@ -72,27 +90,75 @@ func (s *KeysStorageInMemory) DeleteMk(pubKey Key, msgNum uint) error {
72
90
return nil
73
91
}
74
92
75
- // See KeysStorage.
76
- func (s * KeysStorageInMemory ) DeletePk (pubKey Key ) error {
77
- if s .keys == nil {
78
- return nil
93
+ // TruncateMks truncates the number of keys to maxKeys.
94
+ func (s * KeysStorageInMemory ) TruncateMks (sessionID []byte , maxKeys int ) error {
95
+ var seqNos []uint
96
+ // Collect all seq numbers
97
+ for _ , keys := range s .keys {
98
+ for _ , inMemoryKey := range keys {
99
+ if bytes .Equal (inMemoryKey .sessionID , sessionID ) {
100
+ seqNos = append (seqNos , inMemoryKey .seqNum )
101
+ }
102
+ }
79
103
}
80
- if _ , ok := s .keys [pubKey ]; ! ok {
104
+
105
+ // Nothing to do if we haven't reached the limit
106
+ if len (seqNos ) <= maxKeys {
81
107
return nil
82
108
}
83
- delete (s .keys , pubKey )
109
+
110
+ // Take the sequence numbers we care about
111
+ sort .Slice (seqNos , func (i , j int ) bool { return seqNos [i ] < seqNos [j ] })
112
+ toDeleteSlice := seqNos [:len (seqNos )- maxKeys ]
113
+
114
+ // Put in map for easier lookup
115
+ toDelete := make (map [uint ]bool )
116
+
117
+ for _ , seqNo := range toDeleteSlice {
118
+ toDelete [seqNo ] = true
119
+ }
120
+
121
+ for pubKey , keys := range s .keys {
122
+ for i , inMemoryKey := range keys {
123
+ if toDelete [inMemoryKey .seqNum ] && bytes .Equal (inMemoryKey .sessionID , sessionID ) {
124
+ delete (s .keys [pubKey ], i )
125
+ }
126
+ }
127
+ }
128
+
84
129
return nil
85
130
}
86
131
87
- // See KeysStorage.
132
+ // DeleteOldMKeys deletes old message keys for a session.
133
+ func (s * KeysStorageInMemory ) DeleteOldMks (sessionID []byte , deleteUntilSeqKey uint ) error {
134
+ for pubKey , keys := range s .keys {
135
+ for i , inMemoryKey := range keys {
136
+ if inMemoryKey .seqNum <= deleteUntilSeqKey && bytes .Equal (inMemoryKey .sessionID , sessionID ) {
137
+ delete (s .keys [pubKey ], i )
138
+ }
139
+ }
140
+ }
141
+ return nil
142
+ }
143
+
144
+ // Count returns number of message keys stored under the specified key.
88
145
func (s * KeysStorageInMemory ) Count (pubKey Key ) (uint , error ) {
89
146
if s .keys == nil {
90
147
return 0 , nil
91
148
}
92
149
return uint (len (s .keys [pubKey ])), nil
93
150
}
94
151
95
- // See KeysStorage.
152
+ // All returns all the keys
96
153
func (s * KeysStorageInMemory ) All () (map [Key ]map [uint ]Key , error ) {
97
- return s .keys , nil
154
+ response := make (map [Key ]map [uint ]Key )
155
+
156
+ for pubKey , keys := range s .keys {
157
+ response [pubKey ] = make (map [uint ]Key )
158
+ for n , key := range keys {
159
+ response [pubKey ][n ] = key .messageKey
160
+ }
161
+ }
162
+
163
+ return response , nil
98
164
}
0 commit comments