Skip to content

Commit 22a9d8a

Browse files
bifurcationjared2501
authored andcommitted
Merge pull request #177 from ekr/enhance_dtls2
Enhance dtls2
2 parents 30a67d8 + 77b0741 commit 22a9d8a

22 files changed

+1703
-491
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ library](https://golang.org/pkg/crypto/tls/), especially where TLS 1.3 aligns
1818
with earlier TLS versions. However, unnecessary parts will be ruthlessly cut
1919
off.
2020

21+
## DTLS Support
22+
23+
Mint has partial support for DTLS, but that support is not yet complete
24+
and may still contain serious defects.
25+
26+
2127
## Quickstart
2228

2329
Installation is the same as for any other Go package:

client-state-machine.go

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ type clientStateStart struct {
5858
cookie []byte
5959
firstClientHello *HandshakeMessage
6060
helloRetryRequest *HandshakeMessage
61-
hsCtx HandshakeContext
61+
hsCtx *HandshakeContext
6262
}
6363

6464
var _ HandshakeState = &clientStateStart{}
@@ -172,8 +172,10 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
172172
}
173173
ch.CipherSuites = compatibleSuites
174174

175+
// TODO([email protected]): Check that the ticket can be used for early
176+
// data.
175177
// Signal early data if we're going to do it
176-
if len(state.Opts.EarlyData) > 0 {
178+
if state.Config.AllowEarlyData && state.helloRetryRequest == nil {
177179
state.Params.ClientSendingEarlyData = true
178180
ed = &EarlyDataExtension{}
179181
err = ch.Extensions.Add(ed)
@@ -255,9 +257,6 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
255257
earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
256258
logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret)
257259
clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret)
258-
} else if len(state.Opts.EarlyData) > 0 {
259-
logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK")
260-
return nil, nil, AlertInternalError
261260
} else {
262261
clientHello, err = state.hsCtx.hOut.HandshakeMessageFromBody(ch)
263262
if err != nil {
@@ -291,7 +290,6 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
291290
if state.Params.ClientSendingEarlyData {
292291
toSend = append(toSend, []HandshakeAction{
293292
RekeyOut{epoch: EpochEarlyData, KeySet: clientEarlyTrafficKeys},
294-
SendEarlyData{},
295293
}...)
296294
}
297295

@@ -302,7 +300,7 @@ type clientStateWaitSH struct {
302300
Config *Config
303301
Opts ConnectionOptions
304302
Params ConnectionParameters
305-
hsCtx HandshakeContext
303+
hsCtx *HandshakeContext
306304
OfferedDH map[NamedGroup][]byte
307305
OfferedPSK PreSharedKey
308306
PSK []byte
@@ -412,6 +410,11 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
412410
body: h.Sum(nil),
413411
}
414412

413+
state.hsCtx.receivedEndOfFlight()
414+
415+
// TODO([email protected]): Need to rekey with cleartext if we are on 0-RTT
416+
// mode. In DTLS, we also need to bump the sequence number.
417+
// This is a pre-existing defect in Mint. Issue #175.
415418
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]")
416419
return clientStateStart{
417420
Config: state.Config,
@@ -420,7 +423,7 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
420423
cookie: serverCookie.Cookie,
421424
firstClientHello: firstClientHello,
422425
helloRetryRequest: hm,
423-
}, nil, AlertNoAlert
426+
}, []HandshakeAction{ResetOut{1}}, AlertNoAlert
424427
}
425428

426429
// This is SH.
@@ -515,7 +518,6 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
515518
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
516519

517520
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
518-
519521
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]")
520522
nextState := clientStateWaitEE{
521523
Config: state.Config,
@@ -530,13 +532,20 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
530532
toSend := []HandshakeAction{
531533
RekeyIn{epoch: EpochHandshakeData, KeySet: serverHandshakeKeys},
532534
}
535+
// We're definitely not going to have to send anything with
536+
// early data.
537+
if !state.Params.ClientSendingEarlyData {
538+
toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData,
539+
KeySet: makeTrafficKeys(params, clientHandshakeTrafficSecret)})
540+
}
541+
533542
return nextState, toSend, AlertNoAlert
534543
}
535544

536545
type clientStateWaitEE struct {
537546
Config *Config
538547
Params ConnectionParameters
539-
hsCtx HandshakeContext
548+
hsCtx *HandshakeContext
540549
cryptoParams CipherSuiteParams
541550
handshakeHash hash.Hash
542551
masterSecret []byte
@@ -596,6 +605,14 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState,
596605

597606
state.handshakeHash.Write(hm.Marshal())
598607

608+
toSend := []HandshakeAction{}
609+
610+
if state.Params.ClientSendingEarlyData && !state.Params.UsingEarlyData {
611+
// We didn't get 0-RTT, so rekey to handshake.
612+
toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData,
613+
KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)})
614+
}
615+
599616
if state.Params.UsingPSK {
600617
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]")
601618
nextState := clientStateWaitFinished{
@@ -608,7 +625,7 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState,
608625
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
609626
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
610627
}
611-
return nextState, nil, AlertNoAlert
628+
return nextState, toSend, AlertNoAlert
612629
}
613630

614631
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]")
@@ -622,13 +639,13 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState,
622639
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
623640
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
624641
}
625-
return nextState, nil, AlertNoAlert
642+
return nextState, toSend, AlertNoAlert
626643
}
627644

628645
type clientStateWaitCertCR struct {
629646
Config *Config
630647
Params ConnectionParameters
631-
hsCtx HandshakeContext
648+
hsCtx *HandshakeContext
632649
cryptoParams CipherSuiteParams
633650
handshakeHash hash.Hash
634651
masterSecret []byte
@@ -706,7 +723,7 @@ func (state clientStateWaitCertCR) Next(hr handshakeMessageReader) (HandshakeSta
706723
type clientStateWaitCert struct {
707724
Config *Config
708725
Params ConnectionParameters
709-
hsCtx HandshakeContext
726+
hsCtx *HandshakeContext
710727
cryptoParams CipherSuiteParams
711728
handshakeHash hash.Hash
712729

@@ -760,7 +777,7 @@ func (state clientStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState
760777
type clientStateWaitCV struct {
761778
Config *Config
762779
Params ConnectionParameters
763-
hsCtx HandshakeContext
780+
hsCtx *HandshakeContext
764781
cryptoParams CipherSuiteParams
765782
handshakeHash hash.Hash
766783

@@ -861,7 +878,7 @@ func (state clientStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState,
861878

862879
type clientStateWaitFinished struct {
863880
Params ConnectionParameters
864-
hsCtx HandshakeContext
881+
hsCtx *HandshakeContext
865882
cryptoParams CipherSuiteParams
866883
handshakeHash hash.Hash
867884

@@ -933,6 +950,7 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS
933950
toSend := []HandshakeAction{}
934951

935952
if state.Params.UsingEarlyData {
953+
logf(logTypeHandshake, "Sending end of early data")
936954
// Note: We only send EOED if the server is actually going to use the early
937955
// data. Otherwise, it will never see it, and the transcripts will
938956
// mismatch.
@@ -942,10 +960,11 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS
942960

943961
state.handshakeHash.Write(eoedm.Marshal())
944962
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal())
945-
}
946963

947-
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
948-
toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, KeySet: clientHandshakeKeys})
964+
// And then rekey to handshake
965+
toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData,
966+
KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)})
967+
}
949968

950969
if state.Params.UsingClientAuth {
951970
// Extract constraints from certicateRequest
@@ -1045,6 +1064,8 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS
10451064
RekeyOut{epoch: EpochApplicationData, KeySet: clientTrafficKeys},
10461065
}...)
10471066

1067+
state.hsCtx.receivedEndOfFlight()
1068+
10481069
logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]")
10491070
nextState := stateConnected{
10501071
Params: state.Params,

common.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ const (
2525
RecordTypeAlert RecordType = 21
2626
RecordTypeHandshake RecordType = 22
2727
RecordTypeApplicationData RecordType = 23
28+
RecordTypeAck RecordType = 25
2829
)
2930

3031
// enum {...} HandshakeType;
@@ -166,6 +167,8 @@ const (
166167
type State uint8
167168

168169
const (
170+
StateInit = 0
171+
169172
// states valid for the client
170173
StateClientStart State = iota
171174
StateClientWaitSH
@@ -179,6 +182,7 @@ const (
179182
StateServerStart State = iota
180183
StateServerRecvdCH
181184
StateServerNegotiated
185+
StateServerReadPastEarlyData
182186
StateServerWaitEOED
183187
StateServerWaitFlight2
184188
StateServerWaitCert
@@ -211,6 +215,8 @@ func (s State) String() string {
211215
return "Server RECVD_CH"
212216
case StateServerNegotiated:
213217
return "Server NEGOTIATED"
218+
case StateServerReadPastEarlyData:
219+
return "Server READ_PAST_EARLY_DATA"
214220
case StateServerWaitEOED:
215221
return "Server WAIT_EOED"
216222
case StateServerWaitFlight2:
@@ -252,3 +258,9 @@ func (e Epoch) label() string {
252258
}
253259
return "Application data (updated)"
254260
}
261+
262+
func assert(b bool) {
263+
if !b {
264+
panic("Assertion failed")
265+
}
266+
}

common_test.go

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"reflect"
88
"runtime"
9+
"sort"
910
"testing"
1011
)
1112

@@ -17,7 +18,7 @@ func unhex(h string) []byte {
1718
return b
1819
}
1920

20-
func assert(t *testing.T, test bool, msg string) {
21+
func assertTrue(t *testing.T, test bool, msg string) {
2122
t.Helper()
2223
prefix := string("")
2324
for i := 1; ; i++ {
@@ -34,40 +35,40 @@ func assert(t *testing.T, test bool, msg string) {
3435

3536
func assertError(t *testing.T, err error, msg string) {
3637
t.Helper()
37-
assert(t, err != nil, msg)
38+
assertTrue(t, err != nil, msg)
3839
}
3940

4041
func assertNotError(t *testing.T, err error, msg string) {
4142
t.Helper()
4243
if err != nil {
4344
msg += ": " + err.Error()
4445
}
45-
assert(t, err == nil, msg)
46+
assertTrue(t, err == nil, msg)
4647
}
4748

4849
func assertNil(t *testing.T, x interface{}, msg string) {
4950
t.Helper()
50-
assert(t, x == nil, msg)
51+
assertTrue(t, x == nil, msg)
5152
}
5253

5354
func assertNotNil(t *testing.T, x interface{}, msg string) {
5455
t.Helper()
55-
assert(t, x != nil, msg)
56+
assertTrue(t, x != nil, msg)
5657
}
5758

5859
func assertEquals(t *testing.T, a, b interface{}) {
5960
t.Helper()
60-
assert(t, a == b, fmt.Sprintf("%+v != %+v", a, b))
61+
assertTrue(t, a == b, fmt.Sprintf("%+v != %+v", a, b))
6162
}
6263

6364
func assertByteEquals(t *testing.T, a, b []byte) {
6465
t.Helper()
65-
assert(t, bytes.Equal(a, b), fmt.Sprintf("%+v != %+v", hex.EncodeToString(a), hex.EncodeToString(b)))
66+
assertTrue(t, bytes.Equal(a, b), fmt.Sprintf("%+v != %+v", hex.EncodeToString(a), hex.EncodeToString(b)))
6667
}
6768

6869
func assertNotByteEquals(t *testing.T, a, b []byte) {
6970
t.Helper()
70-
assert(t, !bytes.Equal(a, b), fmt.Sprintf("%+v == %+v", hex.EncodeToString(a), hex.EncodeToString(b)))
71+
assertTrue(t, !bytes.Equal(a, b), fmt.Sprintf("%+v == %+v", hex.EncodeToString(a), hex.EncodeToString(b)))
7172
}
7273

7374
func assertCipherSuiteParamsEquals(t *testing.T, a, b CipherSuiteParams) {
@@ -81,12 +82,61 @@ func assertCipherSuiteParamsEquals(t *testing.T, a, b CipherSuiteParams) {
8182

8283
func assertDeepEquals(t *testing.T, a, b interface{}) {
8384
t.Helper()
84-
assert(t, reflect.DeepEqual(a, b), fmt.Sprintf("%+v != %+v", a, b))
85+
assertTrue(t, reflect.DeepEqual(a, b), fmt.Sprintf("%+v != %+v", a, b))
8586
}
8687

8788
func assertSameType(t *testing.T, a, b interface{}) {
8889
t.Helper()
8990
A := reflect.TypeOf(a)
9091
B := reflect.TypeOf(b)
91-
assert(t, A == B, fmt.Sprintf("%s != %s", A.Name(), B.Name()))
92+
assertTrue(t, A == B, fmt.Sprintf("%s != %s", A.Name(), B.Name()))
93+
}
94+
95+
// Utilities for parametrized tests
96+
// Represents the configuration for a given test instance.
97+
type testInstanceState map[string]string
98+
99+
// Helper function.
100+
func runParametrizedInner(t *testing.T, name string, state testInstanceState, inparams map[string][]string, inparamList []string, f parametrizedTest) {
101+
102+
paramName := inparamList[0]
103+
param := inparams[paramName]
104+
next := inparamList[1:]
105+
106+
for _, paramVal := range param {
107+
state[paramName] = paramVal
108+
var n string
109+
if len(name) > 0 {
110+
n = name + "/"
111+
}
112+
n = n + paramName + "=" + paramVal
113+
114+
if len(next) == 0 {
115+
t.Run(n, func(t *testing.T) {
116+
f(t, n, state)
117+
})
118+
continue
119+
}
120+
runParametrizedInner(t, n, state, inparams, next, f)
121+
}
122+
}
123+
124+
// Nominally public API.
125+
type testParameter struct {
126+
name string
127+
vals []string
128+
}
129+
130+
type parametrizedTest func(t *testing.T, name string, p testInstanceState)
131+
132+
// This is the function you call.
133+
func runParametrizedTest(t *testing.T, inparams map[string][]string, f parametrizedTest) {
134+
// Make a sorted list of the names, so we get a consistent order.
135+
il := make([]string, 0)
136+
for k := range inparams {
137+
il = append(il, k)
138+
}
139+
sort.Slice(il, func(i, j int) bool { return il[i] < il[j] })
140+
141+
runParametrizedInner(t, "", make(map[string]string), inparams, il, f)
92142
}

0 commit comments

Comments
 (0)