diff --git a/credentials/alts/internal/conn/common.go b/credentials/alts/internal/conn/common.go index 1795d0c9e372..46617132a456 100644 --- a/credentials/alts/internal/conn/common.go +++ b/credentials/alts/internal/conn/common.go @@ -54,11 +54,10 @@ func SliceForAppend(in []byte, n int) (head, tail []byte) { func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) { // If the size field is not complete, return the provided buffer as // remaining buffer. - if len(b) < MsgLenFieldSize { + length, sufficientBytes := parseMessageLength(b) + if !sufficientBytes { return nil, b, nil } - msgLenField := b[:MsgLenFieldSize] - length := binary.LittleEndian.Uint32(msgLenField) if length > maxLen { return nil, nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen) } @@ -68,3 +67,14 @@ func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) { } return b[:MsgLenFieldSize+length], b[MsgLenFieldSize+length:], nil } + +// parseMessageLength returns the message length based on frame header. It also +// returns a boolean indicating if the buffer contains sufficient bytes to parse +// the length header. If there are insufficient bytes, (0, false) is returned. +func parseMessageLength(b []byte) (uint32, bool) { + if len(b) < MsgLenFieldSize { + return 0, false + } + msgLenField := b[:MsgLenFieldSize] + return binary.LittleEndian.Uint32(msgLenField), true +} diff --git a/credentials/alts/internal/conn/record.go b/credentials/alts/internal/conn/record.go index f1ea7bb20811..d9a18b7f74c7 100644 --- a/credentials/alts/internal/conn/record.go +++ b/credentials/alts/internal/conn/record.go @@ -63,6 +63,8 @@ const ( // The maximum write buffer size. This *must* be multiple of // altsRecordDefaultLength. altsWriteBufferMaxSize = 512 * 1024 // 512KiB + // The initial buffer used to read from the network. + altsReadBufferInitialSize = 32 * 1024 // 32KiB ) var ( @@ -83,7 +85,7 @@ type conn struct { net.Conn crypto ALTSRecordCrypto // buf holds data that has been read from the connection and decrypted, - // but has not yet been returned by Read. + // but has not yet been returned by Read. It is a sub-slice of protected. buf []byte payloadLengthLimit int // protected holds data read from the network but have not yet been @@ -111,21 +113,13 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot } overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead() payloadLengthLimit := altsRecordDefaultLength - overhead - var protectedBuf []byte - if protected == nil { - // We pre-allocate protected to be of size - // 2*altsRecordDefaultLength-1 during initialization. We only - // read from the network into protected when protected does not - // contain a complete frame, which is at most - // altsRecordDefaultLength-1 (bytes). And we read at most - // altsRecordDefaultLength (bytes) data into protected at one - // time. Therefore, 2*altsRecordDefaultLength-1 is large enough - // to buffer data read from the network. - protectedBuf = make([]byte, 0, 2*altsRecordDefaultLength-1) - } else { - protectedBuf = make([]byte, len(protected)) - copy(protectedBuf, protected) - } + // We pre-allocate protected to be of size 32KB during initialization. + // We increase the size of the buffer by the required amount if it can't + // hold a complete encrypted record. + protectedBuf := make([]byte, max(altsReadBufferInitialSize, len(protected))) + // Copy additional data from hanshaker service. + copy(protectedBuf, protected) + protectedBuf = protectedBuf[:len(protected)] altsConn := &conn{ Conn: c, @@ -162,11 +156,21 @@ func (p *conn) Read(b []byte) (n int, err error) { // Check whether a complete frame has been received yet. for len(framedMsg) == 0 { if len(p.protected) == cap(p.protected) { - tmp := make([]byte, len(p.protected), cap(p.protected)+altsRecordDefaultLength) - copy(tmp, p.protected) - p.protected = tmp + // We can parse the length header to know exactly how large + // the buffer needs to be to hold the entire frame. + length, didParse := parseMessageLength(p.protected) + if !didParse { + // The protected buffer is initialized with a capacity of + // larger than 4B. It should always hold the message length + // header. + panic(fmt.Sprintf("protected buffer length shorter than expected: %d vs %d", len(p.protected), MsgLenFieldSize)) + } + oldProtectedBuf := p.protected + p.protected = make([]byte, int(length)+MsgLenFieldSize) + copy(p.protected, oldProtectedBuf) + p.protected = p.protected[:len(oldProtectedBuf)] } - n, err = p.Conn.Read(p.protected[len(p.protected):min(cap(p.protected), len(p.protected)+altsRecordDefaultLength)]) + n, err = p.Conn.Read(p.protected[len(p.protected):cap(p.protected)]) if err != nil { return 0, err } @@ -185,6 +189,15 @@ func (p *conn) Read(b []byte) (n int, err error) { } ciphertext := msg[msgTypeFieldSize:] + // Decrypt directly into the buffer, avoiding a copy from p.buf if + // possible. + if len(b) >= len(ciphertext) { + dec, err := p.crypto.Decrypt(b[:0], ciphertext) + if err != nil { + return 0, err + } + return len(dec), nil + } // Decrypt requires that if the dst and ciphertext alias, they // must alias exactly. Code here used to use msg[:0], but msg // starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than diff --git a/credentials/alts/internal/conn/record_test.go b/credentials/alts/internal/conn/record_test.go index c50fb4c82251..e4992489a189 100644 --- a/credentials/alts/internal/conn/record_test.go +++ b/credentials/alts/internal/conn/record_test.go @@ -26,6 +26,7 @@ import ( "math" "net" "reflect" + "strings" "testing" core "google.golang.org/grpc/credentials/alts/internal" @@ -188,6 +189,48 @@ func (s) TestLargeMsg(t *testing.T) { } } +// TestLargeRecord writes a very large ALTS record and verifies that the server +// receives it correctly. The large ALTS record should cause the reader to +// expand it's read buffer to hold the entire record and store the decrypted +// message until the receiver reads all of the bytes. +func (s) TestLargeRecord(t *testing.T) { + clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil) + msg := []byte(strings.Repeat("a", 2*altsReadBufferInitialSize)) + // Increase the size of ALTS records written by the client. + clientConn.payloadLengthLimit = math.MaxInt32 + if n, err := clientConn.Write(msg); n != len(msg) || err != nil { + t.Fatalf("Write() = %v, %v; want %v, ", n, err, len(msg)) + } + rcvMsg := make([]byte, len(msg)) + if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil { + t.Fatalf("Read() = %v, %v; want %v, ", n, err, len(rcvMsg)) + } + if !reflect.DeepEqual(msg, rcvMsg) { + t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg) + } +} + +// BenchmarkLargeMessage measures the performance of ALTS conns for sending and +// receiving a large message. +func BenchmarkLargeMessage(b *testing.B) { + msgLen := 20 * 1024 * 1024 // 20 MiB + msg := make([]byte, msgLen) + rcvMsg := make([]byte, len(msg)) + b.ResetTimer() + clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil) + for range b.N { + // Write 20 MiB 5 times to transfer a total of 100 MiB. + for range 5 { + if n, err := clientConn.Write(msg); n != len(msg) || err != nil { + b.Fatalf("Write() = %v, %v; want %v, ", n, err, len(msg)) + } + if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil { + b.Fatalf("Read() = %v, %v; want %v, ", n, err, len(rcvMsg)) + } + } + } +} + func testIncorrectMsgType(t *testing.T, rp string) { // framedMsg is an empty ciphertext with correct framing but wrong // message type. diff --git a/credentials/alts/internal/handshaker/handshaker.go b/credentials/alts/internal/handshaker/handshaker.go index 50721f690acb..becd2f3bdf3e 100644 --- a/credentials/alts/internal/handshaker/handshaker.go +++ b/credentials/alts/internal/handshaker/handshaker.go @@ -308,6 +308,7 @@ func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*al // whatever received from the network and send it to the handshaker service. func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) { var lastWriteTime time.Time + buf := make([]byte, frameLimit) for { if len(resp.OutFrames) > 0 { lastWriteTime = time.Now() @@ -318,7 +319,6 @@ func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []b if resp.Result != nil { return resp.Result, extra, nil } - buf := make([]byte, frameLimit) n, err := h.conn.Read(buf) if err != nil && err != io.EOF { return nil, nil, err