Skip to content

Commit 4dcb6cb

Browse files
authored
Merge pull request #6 from status-im/features/delete-skipped-messages
Change handling of skipped/deleted keys
2 parents 321788d + 7279c44 commit 4dcb6cb

File tree

9 files changed

+224
-121
lines changed

9 files changed

+224
-121
lines changed

keys_storage.go

Lines changed: 87 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,26 @@
11
package doubleratchet
22

3+
import (
4+
"bytes"
5+
"sort"
6+
)
7+
38
// KeysStorage is an interface of an abstract in-memory or persistent keys storage.
49
type KeysStorage interface {
510
// Get returns a message key by the given key and message number.
611
Get(k Key, msgNum uint) (mk Key, ok bool, err error)
712

813
// Put saves the given mk under the specified key and msgNum.
9-
Put(k Key, msgNum uint, mk Key) error
14+
Put(sessionID []byte, k Key, msgNum uint, mk Key, keySeqNum uint) error
1015

1116
// DeleteMk ensures there's no message key under the specified key and msgNum.
1217
DeleteMk(k Key, msgNum uint) error
1318

14-
// DeletePk ensures there's no message keys under the specified key.
15-
DeletePk(k Key) error
19+
// DeleteOldMKeys deletes old message keys for a session.
20+
DeleteOldMks(sessionID []byte, deleteUntilSeqKey uint) error
21+
22+
// TruncateMks truncates the number of keys to maxKeys.
23+
TruncateMks(sessionID []byte, maxKeys int) error
1624

1725
// Count returns number of message keys stored under the specified key.
1826
Count(k Key) (uint, error)
@@ -23,10 +31,10 @@ type KeysStorage interface {
2331

2432
// KeysStorageInMemory is an in-memory message keys storage.
2533
type KeysStorageInMemory struct {
26-
keys map[Key]map[uint]Key
34+
keys map[Key]map[uint]InMemoryKey
2735
}
2836

29-
// See KeysStorage.
37+
// Get returns a message key by the given key and message number.
3038
func (s *KeysStorageInMemory) Get(pubKey Key, msgNum uint) (Key, bool, error) {
3139
if s.keys == nil {
3240
return Key{}, false, nil
@@ -39,22 +47,32 @@ func (s *KeysStorageInMemory) Get(pubKey Key, msgNum uint) (Key, bool, error) {
3947
if !ok {
4048
return Key{}, false, nil
4149
}
42-
return mk, true, nil
50+
return mk.messageKey, true, nil
51+
}
52+
53+
type InMemoryKey struct {
54+
messageKey Key
55+
seqNum uint
56+
sessionID []byte
4357
}
4458

45-
// See KeysStorage.
46-
func (s *KeysStorageInMemory) Put(pubKey Key, msgNum uint, mk Key) error {
59+
// Put saves the given mk under the specified key and msgNum.
60+
func (s *KeysStorageInMemory) Put(sessionID []byte, pubKey Key, msgNum uint, mk Key, seqNum uint) error {
4761
if s.keys == nil {
48-
s.keys = make(map[Key]map[uint]Key)
62+
s.keys = make(map[Key]map[uint]InMemoryKey)
4963
}
5064
if _, ok := s.keys[pubKey]; !ok {
51-
s.keys[pubKey] = make(map[uint]Key)
65+
s.keys[pubKey] = make(map[uint]InMemoryKey)
66+
}
67+
s.keys[pubKey][msgNum] = InMemoryKey{
68+
sessionID: sessionID,
69+
messageKey: mk,
70+
seqNum: seqNum,
5271
}
53-
s.keys[pubKey][msgNum] = mk
5472
return nil
5573
}
5674

57-
// See KeysStorage.
75+
// DeleteMk ensures there's no message key under the specified key and msgNum.
5876
func (s *KeysStorageInMemory) DeleteMk(pubKey Key, msgNum uint) error {
5977
if s.keys == nil {
6078
return nil
@@ -72,27 +90,75 @@ func (s *KeysStorageInMemory) DeleteMk(pubKey Key, msgNum uint) error {
7290
return nil
7391
}
7492

75-
// See KeysStorage.
76-
func (s *KeysStorageInMemory) DeletePk(pubKey Key) error {
77-
if s.keys == nil {
78-
return nil
93+
// TruncateMks truncates the number of keys to maxKeys.
94+
func (s *KeysStorageInMemory) TruncateMks(sessionID []byte, maxKeys int) error {
95+
var seqNos []uint
96+
// Collect all seq numbers
97+
for _, keys := range s.keys {
98+
for _, inMemoryKey := range keys {
99+
if bytes.Equal(inMemoryKey.sessionID, sessionID) {
100+
seqNos = append(seqNos, inMemoryKey.seqNum)
101+
}
102+
}
79103
}
80-
if _, ok := s.keys[pubKey]; !ok {
104+
105+
// Nothing to do if we haven't reached the limit
106+
if len(seqNos) <= maxKeys {
81107
return nil
82108
}
83-
delete(s.keys, pubKey)
109+
110+
// Take the sequence numbers we care about
111+
sort.Slice(seqNos, func(i, j int) bool { return seqNos[i] < seqNos[j] })
112+
toDeleteSlice := seqNos[:len(seqNos)-maxKeys]
113+
114+
// Put in map for easier lookup
115+
toDelete := make(map[uint]bool)
116+
117+
for _, seqNo := range toDeleteSlice {
118+
toDelete[seqNo] = true
119+
}
120+
121+
for pubKey, keys := range s.keys {
122+
for i, inMemoryKey := range keys {
123+
if toDelete[inMemoryKey.seqNum] && bytes.Equal(inMemoryKey.sessionID, sessionID) {
124+
delete(s.keys[pubKey], i)
125+
}
126+
}
127+
}
128+
84129
return nil
85130
}
86131

87-
// See KeysStorage.
132+
// DeleteOldMKeys deletes old message keys for a session.
133+
func (s *KeysStorageInMemory) DeleteOldMks(sessionID []byte, deleteUntilSeqKey uint) error {
134+
for pubKey, keys := range s.keys {
135+
for i, inMemoryKey := range keys {
136+
if inMemoryKey.seqNum <= deleteUntilSeqKey && bytes.Equal(inMemoryKey.sessionID, sessionID) {
137+
delete(s.keys[pubKey], i)
138+
}
139+
}
140+
}
141+
return nil
142+
}
143+
144+
// Count returns number of message keys stored under the specified key.
88145
func (s *KeysStorageInMemory) Count(pubKey Key) (uint, error) {
89146
if s.keys == nil {
90147
return 0, nil
91148
}
92149
return uint(len(s.keys[pubKey])), nil
93150
}
94151

95-
// See KeysStorage.
152+
// All returns all the keys
96153
func (s *KeysStorageInMemory) All() (map[Key]map[uint]Key, error) {
97-
return s.keys, nil
154+
response := make(map[Key]map[uint]Key)
155+
156+
for pubKey, keys := range s.keys {
157+
response[pubKey] = make(map[uint]Key)
158+
for n, key := range keys {
159+
response[pubKey][n] = key.messageKey
160+
}
161+
}
162+
163+
return response, nil
98164
}

keys_storage_test.go

Lines changed: 2 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func TestKeysStorageInMemory_Put(t *testing.T) {
2929
ks := &KeysStorageInMemory{}
3030

3131
// Act and assert.
32-
err := ks.Put(pubKey1, 0, mk)
32+
err := ks.Put([]byte("session-id"), pubKey1, 0, mk, 1)
3333
require.NoError(t, err)
3434
}
3535

@@ -58,15 +58,9 @@ func TestKeysStorageInMemory_Flow(t *testing.T) {
5858
// Arrange.
5959
ks := &KeysStorageInMemory{}
6060

61-
t.Run("delete non-existent pubkey", func(t *testing.T) {
62-
// Act and assert.
63-
err := ks.DeletePk(pubKey1)
64-
require.NoError(t, err)
65-
})
66-
6761
t.Run("put and get existing", func(t *testing.T) {
6862
// Act.
69-
err := ks.Put(pubKey1, 0, mk)
63+
err := ks.Put([]byte("session-id"), pubKey1, 0, mk, 1)
7064
require.NoError(t, err)
7165

7266
k, ok, err := ks.Get(pubKey1, 0)
@@ -138,32 +132,4 @@ func TestKeysStorageInMemory_Flow(t *testing.T) {
138132
require.NoError(t, err)
139133
require.EqualValues(t, 0, cnt)
140134
})
141-
142-
t.Run("delete existing pubkey", func(t *testing.T) {
143-
// Act.
144-
err := ks.Put(pubKey1, 0, mk)
145-
require.NoError(t, err)
146-
147-
err = ks.Put(pubKey2, 0, mk)
148-
require.NoError(t, err)
149-
150-
err = ks.DeletePk(pubKey1)
151-
require.NoError(t, err)
152-
153-
err = ks.DeletePk(pubKey1)
154-
require.NoError(t, err)
155-
156-
err = ks.DeletePk(pubKey2)
157-
require.NoError(t, err)
158-
159-
cn1, err := ks.Count(pubKey1)
160-
require.NoError(t, err)
161-
162-
cn2, err := ks.Count(pubKey2)
163-
require.NoError(t, err)
164-
165-
// Assert.
166-
require.Empty(t, cn1)
167-
require.Empty(t, cn2)
168-
})
169135
}

options.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ func WithMaxSkip(n int) option {
1717
}
1818
}
1919

20-
// WithMaxKeep specifies the maximum number of ratchet steps before a message is deleted.
20+
// WithMaxKeep specifies how long we keep message keys, counted in number of messages received
2121
// nolint: golint
2222
func WithMaxKeep(n int) option {
2323
return func(s *State) error {
@@ -29,6 +29,18 @@ func WithMaxKeep(n int) option {
2929
}
3030
}
3131

32+
// WithMaxMessageKeysPerSession specifies the maximum number of message keys per session
33+
// nolint: golint
34+
func WithMaxMessageKeysPerSession(n int) option {
35+
return func(s *State) error {
36+
if n < 0 {
37+
return fmt.Errorf("n must be non-negative")
38+
}
39+
s.MaxMessageKeysPerSession = n
40+
return nil
41+
}
42+
}
43+
3244
// WithKeysStorage replaces the default keys storage with the specified.
3345
// nolint: golint
3446
func WithKeysStorage(ks KeysStorage) option {

session.go

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,13 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
130130
)
131131

132132
// Is there a new ratchet key?
133-
isDHStepped := false
134133
if m.Header.DH != sc.DHr {
135134
if skippedKeys1, err = sc.skipMessageKeys(sc.DHr, uint(m.Header.PN)); err != nil {
136135
return nil, fmt.Errorf("can't skip previous chain message keys: %s", err)
137136
}
138137
if err = sc.dhRatchet(m.Header); err != nil {
139138
return nil, fmt.Errorf("can't perform ratchet step: %s", err)
140139
}
141-
isDHStepped = true
142140
}
143141

144142
// After all, update the current chain.
@@ -151,18 +149,14 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
151149
return nil, fmt.Errorf("can't decrypt: %s", err)
152150
}
153151

152+
// Increment the number of keys
153+
sc.KeysCount++
154+
154155
// Apply changes.
155-
if err := s.applyChanges(sc, append(skippedKeys1, skippedKeys2...)); err != nil {
156+
if err := s.applyChanges(sc, s.id, append(skippedKeys1, skippedKeys2...)); err != nil {
156157
return nil, err
157158
}
158159

159-
if isDHStepped {
160-
err = s.deleteSkippedKeys(s.DHr)
161-
if err != nil {
162-
return nil, err
163-
}
164-
}
165-
166160
// Store state
167161
if err := s.store(); err != nil {
168162
return nil, err

session_he.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,9 @@ func (s *sessionHE) RatchetDecrypt(m MessageHE, ad []byte) ([]byte, error) {
103103
return nil, fmt.Errorf("can't decrypt: %s", err)
104104
}
105105

106-
if err = s.applyChanges(sc, append(skippedKeys1, skippedKeys2...)); err != nil {
106+
if err = s.applyChanges(sc, []byte("FIXME"), append(skippedKeys1, skippedKeys2...)); err != nil {
107107
return nil, fmt.Errorf("failed to apply changes: %s", err)
108108
}
109-
if step {
110-
_ = s.deleteSkippedKeys(s.HKr)
111-
}
112109

113110
return plaintext, nil
114111
}

0 commit comments

Comments
 (0)