Skip to content

Commit c232de3

Browse files
committed
feat: simplify split api
1 parent 0307571 commit c232de3

File tree

5 files changed

+98
-302
lines changed

5 files changed

+98
-302
lines changed

src/asynch.rs

Lines changed: 48 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
1+
use core::sync::atomic::{AtomicBool, Ordering};
2+
13
use crate::common::decrypted_buffer_info::DecryptedBufferInfo;
24
use crate::common::decrypted_read_handler::DecryptedReadHandler;
35
use crate::connection::{decrypt_record, Handshake, State};
46
use crate::key_schedule::KeySchedule;
5-
use crate::key_schedule::{ReadKeySchedule, SharedState, WriteKeySchedule};
7+
use crate::key_schedule::{ReadKeySchedule, WriteKeySchedule};
68
use crate::read_buffer::ReadBuffer;
79
use crate::record::{ClientRecord, ClientRecordHeader};
8-
use crate::record_reader::RecordReader;
9-
use crate::split::{SplitState, SplitStateContainer};
10-
use crate::write_buffer::WriteBuffer;
10+
use crate::record_reader::{RecordReader, RecordReaderBorrowMut};
11+
use crate::write_buffer::{WriteBuffer, WriteBufferBorrowMut};
1112
use crate::TlsError;
1213
use embedded_io::Error as _;
1314
use embedded_io::ErrorType;
1415
use embedded_io_async::{BufRead, Read as AsyncRead, Write as AsyncWrite};
1516

1617
pub use crate::config::*;
17-
#[cfg(feature = "std")]
18-
pub use crate::split::ManagedSplitState;
19-
pub use crate::split::SplitConnectionState;
2018

2119
/// Type representing an async TLS connection. An instance of this type can
2220
/// be used to establish a TLS connection, write and read encrypted data over this connection,
@@ -27,7 +25,7 @@ where
2725
CipherSuite: TlsCipherSuite + 'static,
2826
{
2927
delegate: Socket,
30-
opened: bool,
28+
opened: AtomicBool,
3129
key_schedule: KeySchedule<CipherSuite>,
3230
record_reader: RecordReader<'a>,
3331
record_write_buf: WriteBuffer<'a>,
@@ -39,6 +37,9 @@ where
3937
Socket: AsyncRead + AsyncWrite + 'a,
4038
CipherSuite: TlsCipherSuite + 'static,
4139
{
40+
pub fn is_opened(&mut self) -> bool {
41+
*self.opened.get_mut()
42+
}
4243
/// Create a new TLS connection with the provided context and a async I/O implementation
4344
///
4445
/// NOTE: The record read buffer should be sized to fit an encrypted TLS record. The size of this record
@@ -57,7 +58,7 @@ where
5758
) -> Self {
5859
Self {
5960
delegate,
60-
opened: false,
61+
opened: AtomicBool::new(false),
6162
key_schedule: KeySchedule::new(),
6263
record_reader: RecordReader::new(record_read_buf),
6364
record_write_buf: WriteBuffer::new(record_write_buf),
@@ -101,7 +102,7 @@ where
101102
trace!("State {:?} -> {:?}", state, next_state);
102103
state = next_state;
103104
}
104-
self.opened = true;
105+
*self.opened.get_mut() = true;
105106

106107
Ok(())
107108
}
@@ -115,7 +116,7 @@ where
115116
///
116117
/// Returns the number of bytes buffered/written.
117118
pub async fn write(&mut self, buf: &[u8]) -> Result<usize, TlsError> {
118-
if self.opened {
119+
if self.is_opened() {
119120
if !self
120121
.record_write_buf
121122
.contains(ClientRecordHeader::ApplicationData)
@@ -179,7 +180,7 @@ where
179180

180181
/// Reads buffered data. If nothing is in memory, it'll wait for a TLS record and process it.
181182
pub async fn read_buffered(&mut self) -> Result<ReadBuffer, TlsError> {
182-
if self.opened {
183+
if self.is_opened() {
183184
while self.decrypted.is_empty() {
184185
self.read_application_data().await?;
185186
}
@@ -200,7 +201,7 @@ where
200201
let mut handler = DecryptedReadHandler {
201202
source_buffer: buf_ptr_range,
202203
buffer_info: &mut self.decrypted,
203-
is_open: &mut self.opened,
204+
is_open: self.opened.get_mut(),
204205
};
205206
decrypt_record(
206207
self.key_schedule.read_state(),
@@ -215,9 +216,10 @@ where
215216
async fn close_internal(&mut self) -> Result<(), TlsError> {
216217
self.flush().await?;
217218

219+
let is_opened = self.is_opened();
218220
let (write_key_schedule, read_key_schedule) = self.key_schedule.as_split();
219221
let slice = self.record_write_buf.write_record(
220-
&ClientRecord::close_notify(self.opened),
222+
&ClientRecord::close_notify(is_opened),
221223
write_key_schedule,
222224
Some(read_key_schedule),
223225
)?;
@@ -240,77 +242,33 @@ where
240242
}
241243
}
242244

243-
#[cfg(feature = "std")]
244-
pub fn split(
245-
self,
246-
) -> (
247-
TlsReader<'a, Socket, CipherSuite, ManagedSplitState>,
248-
TlsWriter<'a, Socket, CipherSuite, ManagedSplitState>,
249-
)
250-
where
251-
Socket: Clone,
252-
{
253-
self.split_with(ManagedSplitState::new())
254-
}
255-
256-
#[allow(clippy::type_complexity)] // Requires inherent type aliases to solve well.
257-
pub fn split_with<StateContainer>(
258-
self,
259-
state: StateContainer,
245+
pub fn split<'b>(
246+
&'b mut self,
260247
) -> (
261-
TlsReader<'a, Socket, CipherSuite, StateContainer::State>,
262-
TlsWriter<'a, Socket, CipherSuite, StateContainer::State>,
248+
TlsReader<'b, Socket, CipherSuite>,
249+
TlsWriter<'b, Socket, CipherSuite>,
263250
)
264251
where
265252
Socket: Clone,
266-
StateContainer: SplitStateContainer,
267253
{
268-
let state = state.state();
269-
state.set_open(self.opened);
270-
271-
let (shared, wks, rks) = self.key_schedule.split();
254+
let (wks, rks) = self.key_schedule.as_split();
272255

273256
let reader = TlsReader {
274-
state: state.clone(),
257+
opened: &self.opened,
275258
delegate: self.delegate.clone(),
276259
key_schedule: rks,
277-
record_reader: self.record_reader,
278-
decrypted: self.decrypted,
260+
record_reader: self.record_reader.reborrow_mut(),
261+
decrypted: &mut self.decrypted,
279262
};
280263
let writer = TlsWriter {
281-
state,
282-
delegate: self.delegate,
283-
key_schedule_shared: shared,
264+
opened: &self.opened,
265+
delegate: self.delegate.clone(),
284266
key_schedule: wks,
285-
record_write_buf: self.record_write_buf,
267+
record_write_buf: self.record_write_buf.reborrow_mut(),
286268
};
287269

288270
(reader, writer)
289271
}
290-
291-
pub fn unsplit<State>(
292-
reader: TlsReader<'a, Socket, CipherSuite, State>,
293-
writer: TlsWriter<'a, Socket, CipherSuite, State>,
294-
) -> Self
295-
where
296-
Socket: Clone,
297-
State: SplitState,
298-
{
299-
debug_assert!(reader.state.same(&writer.state));
300-
301-
TlsConnection {
302-
delegate: writer.delegate,
303-
opened: writer.state.is_open(),
304-
key_schedule: KeySchedule::unsplit(
305-
writer.key_schedule_shared,
306-
writer.key_schedule,
307-
reader.key_schedule,
308-
),
309-
record_reader: reader.record_reader,
310-
record_write_buf: writer.record_write_buf,
311-
decrypted: reader.decrypted,
312-
}
313-
}
314272
}
315273

316274
impl<'a, Socket, CipherSuite> ErrorType for TlsConnection<'a, Socket, CipherSuite>
@@ -359,18 +317,18 @@ where
359317
}
360318
}
361319

362-
pub struct TlsReader<'a, Socket, CipherSuite, State>
320+
pub struct TlsReader<'a, Socket, CipherSuite>
363321
where
364322
CipherSuite: TlsCipherSuite + 'static,
365323
{
366-
state: State,
324+
opened: &'a AtomicBool,
367325
delegate: Socket,
368-
key_schedule: ReadKeySchedule<CipherSuite>,
369-
record_reader: RecordReader<'a>,
370-
decrypted: DecryptedBufferInfo,
326+
key_schedule: &'a mut ReadKeySchedule<CipherSuite>,
327+
record_reader: RecordReaderBorrowMut<'a>,
328+
decrypted: &'a mut DecryptedBufferInfo,
371329
}
372330

373-
impl<'a, Socket, CipherSuite, State> AsRef<Socket> for TlsReader<'a, Socket, CipherSuite, State>
331+
impl<'a, Socket, CipherSuite> AsRef<Socket> for TlsReader<'a, Socket, CipherSuite>
374332
where
375333
CipherSuite: TlsCipherSuite + 'static,
376334
{
@@ -379,19 +337,18 @@ where
379337
}
380338
}
381339

382-
impl<'a, Socket, CipherSuite, State> TlsReader<'a, Socket, CipherSuite, State>
340+
impl<'a, Socket, CipherSuite> TlsReader<'a, Socket, CipherSuite>
383341
where
384342
Socket: AsyncRead + 'a,
385343
CipherSuite: TlsCipherSuite + 'static,
386-
State: SplitState,
387344
{
388345
fn create_read_buffer(&mut self) -> ReadBuffer {
389346
self.decrypted.create_read_buffer(self.record_reader.buf)
390347
}
391348

392349
/// Reads buffered data. If nothing is in memory, it'll wait for a TLS record and process it.
393350
pub async fn read_buffered(&mut self) -> Result<ReadBuffer, TlsError> {
394-
if self.state.is_open() {
351+
if self.opened.load(Ordering::Acquire) {
395352
while self.decrypted.is_empty() {
396353
self.read_application_data().await?;
397354
}
@@ -409,7 +366,7 @@ where
409366
.read(&mut self.delegate, &mut self.key_schedule)
410367
.await?;
411368

412-
let mut opened = self.state.is_open();
369+
let mut opened = self.opened.load(Ordering::Acquire);
413370
let mut handler = DecryptedReadHandler {
414371
source_buffer: buf_ptr_range,
415372
buffer_info: &mut self.decrypted,
@@ -420,24 +377,23 @@ where
420377
});
421378

422379
if !opened {
423-
self.state.set_open(false);
380+
self.opened.store(false, Ordering::Release);
424381
}
425382
result
426383
}
427384
}
428385

429-
pub struct TlsWriter<'a, Socket, CipherSuite, State>
386+
pub struct TlsWriter<'a, Socket, CipherSuite>
430387
where
431388
CipherSuite: TlsCipherSuite + 'static,
432389
{
433-
state: State,
390+
opened: &'a AtomicBool,
434391
delegate: Socket,
435-
key_schedule_shared: SharedState<CipherSuite>,
436-
key_schedule: WriteKeySchedule<CipherSuite>,
437-
record_write_buf: WriteBuffer<'a>,
392+
key_schedule: &'a mut WriteKeySchedule<CipherSuite>,
393+
record_write_buf: WriteBufferBorrowMut<'a>,
438394
}
439395

440-
impl<'a, Socket, CipherSuite, State> AsRef<Socket> for TlsWriter<'a, Socket, CipherSuite, State>
396+
impl<'a, Socket, CipherSuite> AsRef<Socket> for TlsWriter<'a, Socket, CipherSuite>
441397
where
442398
CipherSuite: TlsCipherSuite + 'static,
443399
{
@@ -446,25 +402,24 @@ where
446402
}
447403
}
448404

449-
impl<'a, Socket, CipherSuite, State> ErrorType for TlsWriter<'a, Socket, CipherSuite, State>
405+
impl<'a, Socket, CipherSuite> ErrorType for TlsWriter<'a, Socket, CipherSuite>
450406
where
451407
CipherSuite: TlsCipherSuite + 'static,
452408
{
453409
type Error = TlsError;
454410
}
455411

456-
impl<'a, Socket, CipherSuite, State> ErrorType for TlsReader<'a, Socket, CipherSuite, State>
412+
impl<'a, Socket, CipherSuite> ErrorType for TlsReader<'a, Socket, CipherSuite>
457413
where
458414
CipherSuite: TlsCipherSuite + 'static,
459415
{
460416
type Error = TlsError;
461417
}
462418

463-
impl<'a, Socket, CipherSuite, State> AsyncRead for TlsReader<'a, Socket, CipherSuite, State>
419+
impl<'a, Socket, CipherSuite> AsyncRead for TlsReader<'a, Socket, CipherSuite>
464420
where
465421
Socket: AsyncRead + 'a,
466422
CipherSuite: TlsCipherSuite + 'static,
467-
State: SplitState,
468423
{
469424
async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
470425
if buf.is_empty() {
@@ -479,11 +434,10 @@ where
479434
}
480435
}
481436

482-
impl<'a, Socket, CipherSuite, State> BufRead for TlsReader<'a, Socket, CipherSuite, State>
437+
impl<'a, Socket, CipherSuite> BufRead for TlsReader<'a, Socket, CipherSuite>
483438
where
484439
Socket: AsyncRead + 'a,
485440
CipherSuite: TlsCipherSuite + 'static,
486-
State: SplitState,
487441
{
488442
async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
489443
self.read_buffered().await.map(|mut buf| buf.peek_all())
@@ -494,14 +448,13 @@ where
494448
}
495449
}
496450

497-
impl<'a, Socket, CipherSuite, State> AsyncWrite for TlsWriter<'a, Socket, CipherSuite, State>
451+
impl<'a, Socket, CipherSuite> AsyncWrite for TlsWriter<'a, Socket, CipherSuite>
498452
where
499453
Socket: AsyncWrite + 'a,
500454
CipherSuite: TlsCipherSuite + 'static,
501-
State: SplitState,
502455
{
503456
async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
504-
if self.state.is_open() {
457+
if self.opened.load(Ordering::Acquire) {
505458
if !self
506459
.record_write_buf
507460
.contains(ClientRecordHeader::ApplicationData)

0 commit comments

Comments
 (0)