1
+ use core:: sync:: atomic:: { AtomicBool , Ordering } ;
2
+
1
3
use crate :: common:: decrypted_buffer_info:: DecryptedBufferInfo ;
2
4
use crate :: common:: decrypted_read_handler:: DecryptedReadHandler ;
3
5
use crate :: connection:: { decrypt_record, Handshake , State } ;
4
6
use crate :: key_schedule:: KeySchedule ;
5
- use crate :: key_schedule:: { ReadKeySchedule , SharedState , WriteKeySchedule } ;
7
+ use crate :: key_schedule:: { ReadKeySchedule , WriteKeySchedule } ;
6
8
use crate :: read_buffer:: ReadBuffer ;
7
9
use crate :: record:: { ClientRecord , ClientRecordHeader } ;
8
- use crate :: record_reader:: RecordReader ;
9
- use crate :: split:: { SplitState , SplitStateContainer } ;
10
- use crate :: write_buffer:: WriteBuffer ;
10
+ use crate :: record_reader:: { RecordReader , RecordReaderBorrowMut } ;
11
+ use crate :: write_buffer:: { WriteBuffer , WriteBufferBorrowMut } ;
11
12
use crate :: TlsError ;
12
13
use embedded_io:: Error as _;
13
14
use embedded_io:: ErrorType ;
14
15
use embedded_io_async:: { BufRead , Read as AsyncRead , Write as AsyncWrite } ;
15
16
16
17
pub use crate :: config:: * ;
17
- #[ cfg( feature = "std" ) ]
18
- pub use crate :: split:: ManagedSplitState ;
19
- pub use crate :: split:: SplitConnectionState ;
20
18
21
19
/// Type representing an async TLS connection. An instance of this type can
22
20
/// be used to establish a TLS connection, write and read encrypted data over this connection,
27
25
CipherSuite : TlsCipherSuite + ' static ,
28
26
{
29
27
delegate : Socket ,
30
- opened : bool ,
28
+ opened : AtomicBool ,
31
29
key_schedule : KeySchedule < CipherSuite > ,
32
30
record_reader : RecordReader < ' a > ,
33
31
record_write_buf : WriteBuffer < ' a > ,
39
37
Socket : AsyncRead + AsyncWrite + ' a ,
40
38
CipherSuite : TlsCipherSuite + ' static ,
41
39
{
40
+ pub fn is_opened ( & mut self ) -> bool {
41
+ * self . opened . get_mut ( )
42
+ }
42
43
/// Create a new TLS connection with the provided context and a async I/O implementation
43
44
///
44
45
/// NOTE: The record read buffer should be sized to fit an encrypted TLS record. The size of this record
57
58
) -> Self {
58
59
Self {
59
60
delegate,
60
- opened : false ,
61
+ opened : AtomicBool :: new ( false ) ,
61
62
key_schedule : KeySchedule :: new ( ) ,
62
63
record_reader : RecordReader :: new ( record_read_buf) ,
63
64
record_write_buf : WriteBuffer :: new ( record_write_buf) ,
@@ -101,7 +102,7 @@ where
101
102
trace ! ( "State {:?} -> {:?}" , state, next_state) ;
102
103
state = next_state;
103
104
}
104
- self . opened = true ;
105
+ * self . opened . get_mut ( ) = true ;
105
106
106
107
Ok ( ( ) )
107
108
}
@@ -115,7 +116,7 @@ where
115
116
///
116
117
/// Returns the number of bytes buffered/written.
117
118
pub async fn write ( & mut self , buf : & [ u8 ] ) -> Result < usize , TlsError > {
118
- if self . opened {
119
+ if self . is_opened ( ) {
119
120
if !self
120
121
. record_write_buf
121
122
. contains ( ClientRecordHeader :: ApplicationData )
@@ -179,7 +180,7 @@ where
179
180
180
181
/// Reads buffered data. If nothing is in memory, it'll wait for a TLS record and process it.
181
182
pub async fn read_buffered ( & mut self ) -> Result < ReadBuffer , TlsError > {
182
- if self . opened {
183
+ if self . is_opened ( ) {
183
184
while self . decrypted . is_empty ( ) {
184
185
self . read_application_data ( ) . await ?;
185
186
}
@@ -200,7 +201,7 @@ where
200
201
let mut handler = DecryptedReadHandler {
201
202
source_buffer : buf_ptr_range,
202
203
buffer_info : & mut self . decrypted ,
203
- is_open : & mut self . opened ,
204
+ is_open : self . opened . get_mut ( ) ,
204
205
} ;
205
206
decrypt_record (
206
207
self . key_schedule . read_state ( ) ,
@@ -215,9 +216,10 @@ where
215
216
async fn close_internal ( & mut self ) -> Result < ( ) , TlsError > {
216
217
self . flush ( ) . await ?;
217
218
219
+ let is_opened = self . is_opened ( ) ;
218
220
let ( write_key_schedule, read_key_schedule) = self . key_schedule . as_split ( ) ;
219
221
let slice = self . record_write_buf . write_record (
220
- & ClientRecord :: close_notify ( self . opened ) ,
222
+ & ClientRecord :: close_notify ( is_opened ) ,
221
223
write_key_schedule,
222
224
Some ( read_key_schedule) ,
223
225
) ?;
@@ -240,77 +242,33 @@ where
240
242
}
241
243
}
242
244
243
- #[ cfg( feature = "std" ) ]
244
- pub fn split (
245
- self ,
246
- ) -> (
247
- TlsReader < ' a , Socket , CipherSuite , ManagedSplitState > ,
248
- TlsWriter < ' a , Socket , CipherSuite , ManagedSplitState > ,
249
- )
250
- where
251
- Socket : Clone ,
252
- {
253
- self . split_with ( ManagedSplitState :: new ( ) )
254
- }
255
-
256
- #[ allow( clippy:: type_complexity) ] // Requires inherent type aliases to solve well.
257
- pub fn split_with < StateContainer > (
258
- self ,
259
- state : StateContainer ,
245
+ pub fn split < ' b > (
246
+ & ' b mut self ,
260
247
) -> (
261
- TlsReader < ' a , Socket , CipherSuite , StateContainer :: State > ,
262
- TlsWriter < ' a , Socket , CipherSuite , StateContainer :: State > ,
248
+ TlsReader < ' b , Socket , CipherSuite > ,
249
+ TlsWriter < ' b , Socket , CipherSuite > ,
263
250
)
264
251
where
265
252
Socket : Clone ,
266
- StateContainer : SplitStateContainer ,
267
253
{
268
- let state = state. state ( ) ;
269
- state. set_open ( self . opened ) ;
270
-
271
- let ( shared, wks, rks) = self . key_schedule . split ( ) ;
254
+ let ( wks, rks) = self . key_schedule . as_split ( ) ;
272
255
273
256
let reader = TlsReader {
274
- state : state . clone ( ) ,
257
+ opened : & self . opened ,
275
258
delegate : self . delegate . clone ( ) ,
276
259
key_schedule : rks,
277
- record_reader : self . record_reader ,
278
- decrypted : self . decrypted ,
260
+ record_reader : self . record_reader . reborrow_mut ( ) ,
261
+ decrypted : & mut self . decrypted ,
279
262
} ;
280
263
let writer = TlsWriter {
281
- state,
282
- delegate : self . delegate ,
283
- key_schedule_shared : shared,
264
+ opened : & self . opened ,
265
+ delegate : self . delegate . clone ( ) ,
284
266
key_schedule : wks,
285
- record_write_buf : self . record_write_buf ,
267
+ record_write_buf : self . record_write_buf . reborrow_mut ( ) ,
286
268
} ;
287
269
288
270
( reader, writer)
289
271
}
290
-
291
- pub fn unsplit < State > (
292
- reader : TlsReader < ' a , Socket , CipherSuite , State > ,
293
- writer : TlsWriter < ' a , Socket , CipherSuite , State > ,
294
- ) -> Self
295
- where
296
- Socket : Clone ,
297
- State : SplitState ,
298
- {
299
- debug_assert ! ( reader. state. same( & writer. state) ) ;
300
-
301
- TlsConnection {
302
- delegate : writer. delegate ,
303
- opened : writer. state . is_open ( ) ,
304
- key_schedule : KeySchedule :: unsplit (
305
- writer. key_schedule_shared ,
306
- writer. key_schedule ,
307
- reader. key_schedule ,
308
- ) ,
309
- record_reader : reader. record_reader ,
310
- record_write_buf : writer. record_write_buf ,
311
- decrypted : reader. decrypted ,
312
- }
313
- }
314
272
}
315
273
316
274
impl < ' a , Socket , CipherSuite > ErrorType for TlsConnection < ' a , Socket , CipherSuite >
@@ -359,18 +317,18 @@ where
359
317
}
360
318
}
361
319
362
- pub struct TlsReader < ' a , Socket , CipherSuite , State >
320
+ pub struct TlsReader < ' a , Socket , CipherSuite >
363
321
where
364
322
CipherSuite : TlsCipherSuite + ' static ,
365
323
{
366
- state : State ,
324
+ opened : & ' a AtomicBool ,
367
325
delegate : Socket ,
368
- key_schedule : ReadKeySchedule < CipherSuite > ,
369
- record_reader : RecordReader < ' a > ,
370
- decrypted : DecryptedBufferInfo ,
326
+ key_schedule : & ' a mut ReadKeySchedule < CipherSuite > ,
327
+ record_reader : RecordReaderBorrowMut < ' a > ,
328
+ decrypted : & ' a mut DecryptedBufferInfo ,
371
329
}
372
330
373
- impl < ' a , Socket , CipherSuite , State > AsRef < Socket > for TlsReader < ' a , Socket , CipherSuite , State >
331
+ impl < ' a , Socket , CipherSuite > AsRef < Socket > for TlsReader < ' a , Socket , CipherSuite >
374
332
where
375
333
CipherSuite : TlsCipherSuite + ' static ,
376
334
{
@@ -379,19 +337,18 @@ where
379
337
}
380
338
}
381
339
382
- impl < ' a , Socket , CipherSuite , State > TlsReader < ' a , Socket , CipherSuite , State >
340
+ impl < ' a , Socket , CipherSuite > TlsReader < ' a , Socket , CipherSuite >
383
341
where
384
342
Socket : AsyncRead + ' a ,
385
343
CipherSuite : TlsCipherSuite + ' static ,
386
- State : SplitState ,
387
344
{
388
345
fn create_read_buffer ( & mut self ) -> ReadBuffer {
389
346
self . decrypted . create_read_buffer ( self . record_reader . buf )
390
347
}
391
348
392
349
/// Reads buffered data. If nothing is in memory, it'll wait for a TLS record and process it.
393
350
pub async fn read_buffered ( & mut self ) -> Result < ReadBuffer , TlsError > {
394
- if self . state . is_open ( ) {
351
+ if self . opened . load ( Ordering :: Acquire ) {
395
352
while self . decrypted . is_empty ( ) {
396
353
self . read_application_data ( ) . await ?;
397
354
}
@@ -409,7 +366,7 @@ where
409
366
. read ( & mut self . delegate , & mut self . key_schedule )
410
367
. await ?;
411
368
412
- let mut opened = self . state . is_open ( ) ;
369
+ let mut opened = self . opened . load ( Ordering :: Acquire ) ;
413
370
let mut handler = DecryptedReadHandler {
414
371
source_buffer : buf_ptr_range,
415
372
buffer_info : & mut self . decrypted ,
@@ -420,24 +377,23 @@ where
420
377
} ) ;
421
378
422
379
if !opened {
423
- self . state . set_open ( false ) ;
380
+ self . opened . store ( false , Ordering :: Release ) ;
424
381
}
425
382
result
426
383
}
427
384
}
428
385
429
- pub struct TlsWriter < ' a , Socket , CipherSuite , State >
386
+ pub struct TlsWriter < ' a , Socket , CipherSuite >
430
387
where
431
388
CipherSuite : TlsCipherSuite + ' static ,
432
389
{
433
- state : State ,
390
+ opened : & ' a AtomicBool ,
434
391
delegate : Socket ,
435
- key_schedule_shared : SharedState < CipherSuite > ,
436
- key_schedule : WriteKeySchedule < CipherSuite > ,
437
- record_write_buf : WriteBuffer < ' a > ,
392
+ key_schedule : & ' a mut WriteKeySchedule < CipherSuite > ,
393
+ record_write_buf : WriteBufferBorrowMut < ' a > ,
438
394
}
439
395
440
- impl < ' a , Socket , CipherSuite , State > AsRef < Socket > for TlsWriter < ' a , Socket , CipherSuite , State >
396
+ impl < ' a , Socket , CipherSuite > AsRef < Socket > for TlsWriter < ' a , Socket , CipherSuite >
441
397
where
442
398
CipherSuite : TlsCipherSuite + ' static ,
443
399
{
@@ -446,25 +402,24 @@ where
446
402
}
447
403
}
448
404
449
- impl < ' a , Socket , CipherSuite , State > ErrorType for TlsWriter < ' a , Socket , CipherSuite , State >
405
+ impl < ' a , Socket , CipherSuite > ErrorType for TlsWriter < ' a , Socket , CipherSuite >
450
406
where
451
407
CipherSuite : TlsCipherSuite + ' static ,
452
408
{
453
409
type Error = TlsError ;
454
410
}
455
411
456
- impl < ' a , Socket , CipherSuite , State > ErrorType for TlsReader < ' a , Socket , CipherSuite , State >
412
+ impl < ' a , Socket , CipherSuite > ErrorType for TlsReader < ' a , Socket , CipherSuite >
457
413
where
458
414
CipherSuite : TlsCipherSuite + ' static ,
459
415
{
460
416
type Error = TlsError ;
461
417
}
462
418
463
- impl < ' a , Socket , CipherSuite , State > AsyncRead for TlsReader < ' a , Socket , CipherSuite , State >
419
+ impl < ' a , Socket , CipherSuite > AsyncRead for TlsReader < ' a , Socket , CipherSuite >
464
420
where
465
421
Socket : AsyncRead + ' a ,
466
422
CipherSuite : TlsCipherSuite + ' static ,
467
- State : SplitState ,
468
423
{
469
424
async fn read ( & mut self , buf : & mut [ u8 ] ) -> Result < usize , Self :: Error > {
470
425
if buf. is_empty ( ) {
@@ -479,11 +434,10 @@ where
479
434
}
480
435
}
481
436
482
- impl < ' a , Socket , CipherSuite , State > BufRead for TlsReader < ' a , Socket , CipherSuite , State >
437
+ impl < ' a , Socket , CipherSuite > BufRead for TlsReader < ' a , Socket , CipherSuite >
483
438
where
484
439
Socket : AsyncRead + ' a ,
485
440
CipherSuite : TlsCipherSuite + ' static ,
486
- State : SplitState ,
487
441
{
488
442
async fn fill_buf ( & mut self ) -> Result < & [ u8 ] , Self :: Error > {
489
443
self . read_buffered ( ) . await . map ( |mut buf| buf. peek_all ( ) )
@@ -494,14 +448,13 @@ where
494
448
}
495
449
}
496
450
497
- impl < ' a , Socket , CipherSuite , State > AsyncWrite for TlsWriter < ' a , Socket , CipherSuite , State >
451
+ impl < ' a , Socket , CipherSuite > AsyncWrite for TlsWriter < ' a , Socket , CipherSuite >
498
452
where
499
453
Socket : AsyncWrite + ' a ,
500
454
CipherSuite : TlsCipherSuite + ' static ,
501
- State : SplitState ,
502
455
{
503
456
async fn write ( & mut self , buf : & [ u8 ] ) -> Result < usize , Self :: Error > {
504
- if self . state . is_open ( ) {
457
+ if self . opened . load ( Ordering :: Acquire ) {
505
458
if !self
506
459
. record_write_buf
507
460
. contains ( ClientRecordHeader :: ApplicationData )
0 commit comments