Skip to content

credentials/alts: Optimize reads #8204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions credentials/alts/internal/conn/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,10 @@ func SliceForAppend(in []byte, n int) (head, tail []byte) {
func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
// If the size field is not complete, return the provided buffer as
// remaining buffer.
if len(b) < MsgLenFieldSize {
length, sufficientBytes := parseMessageLength(b)
if !sufficientBytes {
return nil, b, nil
}
msgLenField := b[:MsgLenFieldSize]
length := binary.LittleEndian.Uint32(msgLenField)
if length > maxLen {
return nil, nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen)
}
Expand All @@ -68,3 +67,14 @@ func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
}
return b[:MsgLenFieldSize+length], b[MsgLenFieldSize+length:], nil
}

// parseMessageLength returns the message length based on frame header. It also
// returns a boolean indicating if the buffer contains sufficient bytes to parse
// the length header. If there are insufficient bytes, (0, false) is returned.
func parseMessageLength(b []byte) (uint32, bool) {
if len(b) < MsgLenFieldSize {
return 0, false
}
msgLenField := b[:MsgLenFieldSize]
return binary.LittleEndian.Uint32(msgLenField), true
}
53 changes: 33 additions & 20 deletions credentials/alts/internal/conn/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
// The maximum write buffer size. This *must* be multiple of
// altsRecordDefaultLength.
altsWriteBufferMaxSize = 512 * 1024 // 512KiB
// The initial buffer used to read from the network.
altsReadBufferInitialSize = 32 * 1024 // 32KiB
)

var (
Expand All @@ -83,7 +85,7 @@
net.Conn
crypto ALTSRecordCrypto
// buf holds data that has been read from the connection and decrypted,
// but has not yet been returned by Read.
// but has not yet been returned by Read. It is a sub-slice of protected.
buf []byte
payloadLengthLimit int
// protected holds data read from the network but have not yet been
Expand Down Expand Up @@ -111,21 +113,13 @@
}
overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead()
payloadLengthLimit := altsRecordDefaultLength - overhead
var protectedBuf []byte
if protected == nil {
// We pre-allocate protected to be of size
// 2*altsRecordDefaultLength-1 during initialization. We only
// read from the network into protected when protected does not
// contain a complete frame, which is at most
// altsRecordDefaultLength-1 (bytes). And we read at most
// altsRecordDefaultLength (bytes) data into protected at one
// time. Therefore, 2*altsRecordDefaultLength-1 is large enough
// to buffer data read from the network.
protectedBuf = make([]byte, 0, 2*altsRecordDefaultLength-1)
} else {
protectedBuf = make([]byte, len(protected))
copy(protectedBuf, protected)
}
// We pre-allocate protected to be of size 32KB during initialization.
// We increase the size of the buffer by the required amount if it can't
// hold a complete encrypted record.
protectedBuf := make([]byte, max(altsReadBufferInitialSize, len(protected)))
// Copy additional data from hanshaker service.
copy(protectedBuf, protected)
protectedBuf = protectedBuf[:len(protected)]

altsConn := &conn{
Conn: c,
Expand Down Expand Up @@ -162,11 +156,21 @@
// Check whether a complete frame has been received yet.
for len(framedMsg) == 0 {
if len(p.protected) == cap(p.protected) {
tmp := make([]byte, len(p.protected), cap(p.protected)+altsRecordDefaultLength)
copy(tmp, p.protected)
p.protected = tmp
// We can parse the length header to know exactly how large
// the buffer needs to be to hold the entire frame.
length, didParse := parseMessageLength(p.protected)
if !didParse {
// The protected buffer is initialized with a capacity of
// larger than 4B. It should always hold the message length
// header.
panic(fmt.Sprintf("protected buffer length shorter than expected: %d vs %d", len(p.protected), MsgLenFieldSize))

Check warning on line 166 in credentials/alts/internal/conn/record.go

View check run for this annotation

Codecov / codecov/patch

credentials/alts/internal/conn/record.go#L163-L166

Added lines #L163 - L166 were not covered by tests
}
oldProtectedBuf := p.protected
p.protected = make([]byte, int(length)+MsgLenFieldSize)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't get to review this very last changeset - here on 169 we make a specific length int(length)+MsgLenFieldSize, copy to it, then slice to a different var length len(oldProtectedBuf)

I suspected something is not quite right here or can be made a little clearer?

Copy link
Contributor Author

@arjan-bal arjan-bal Apr 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new buffer that we allocate must be able to hold the entire encrypted record. After reading the message header, we know the length of the record is: length parsed from the message header + size of the message length header. This is the capacity, but the length of the new buffer should be set to the number of bytes that are already read. So we set the length to length of the existing buffer and copy its contents. Let me raise a PR with some commentry.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created a PR to add clarity, PTAL: #8232

copy(p.protected, oldProtectedBuf)
p.protected = p.protected[:len(oldProtectedBuf)]
}
n, err = p.Conn.Read(p.protected[len(p.protected):min(cap(p.protected), len(p.protected)+altsRecordDefaultLength)])
n, err = p.Conn.Read(p.protected[len(p.protected):cap(p.protected)])
if err != nil {
return 0, err
}
Expand All @@ -185,6 +189,15 @@
}
ciphertext := msg[msgTypeFieldSize:]

// Decrypt directly into the buffer, avoiding a copy from p.buf if
// possible.
if len(b) >= len(ciphertext) {
dec, err := p.crypto.Decrypt(b[:0], ciphertext)
if err != nil {
return 0, err
}

Check warning on line 198 in credentials/alts/internal/conn/record.go

View check run for this annotation

Codecov / codecov/patch

credentials/alts/internal/conn/record.go#L197-L198

Added lines #L197 - L198 were not covered by tests
return len(dec), nil
}
// Decrypt requires that if the dst and ciphertext alias, they
// must alias exactly. Code here used to use msg[:0], but msg
// starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than
Expand Down
43 changes: 43 additions & 0 deletions credentials/alts/internal/conn/record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"math"
"net"
"reflect"
"strings"
"testing"

core "google.golang.org/grpc/credentials/alts/internal"
Expand Down Expand Up @@ -188,6 +189,48 @@ func (s) TestLargeMsg(t *testing.T) {
}
}

// TestLargeRecord writes a very large ALTS record and verifies that the server
// receives it correctly. The large ALTS record should cause the reader to
// expand it's read buffer to hold the entire record and store the decrypted
// message until the receiver reads all of the bytes.
func (s) TestLargeRecord(t *testing.T) {
clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil)
msg := []byte(strings.Repeat("a", 2*altsReadBufferInitialSize))
// Increase the size of ALTS records written by the client.
clientConn.payloadLengthLimit = math.MaxInt32
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
}
rcvMsg := make([]byte, len(msg))
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
}
if !reflect.DeepEqual(msg, rcvMsg) {
t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg)
}
}

// BenchmarkLargeMessage measures the performance of ALTS conns for sending and
// receiving a large message.
func BenchmarkLargeMessage(b *testing.B) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gtcooke94 Do we have any GitHub actions that already run these benchmarks or should we add any if not?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think someone else on the Go team would be the person to ask here - I don't know much about the CI setup. @arjan-bal do you know about the grpc-go github CI and benchmarking?

Copy link
Contributor Author

@arjan-bal arjan-bal Apr 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't run benchmarks as part of CI. We have benchmarks here that we ask PR authors to run when reviewing PRs that effect performance. We can have a similar benchmark for ALTS or modify the existing benchmark to support ALTS.

We have performance dashboard for all languages here, but we don't have alerts setup for regressions: https://grafana-dot-grpc-testing.appspot.com/?orgId=1

msgLen := 20 * 1024 * 1024 // 20 MiB
msg := make([]byte, msgLen)
rcvMsg := make([]byte, len(msg))
b.ResetTimer()
clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil)
for range b.N {
// Write 20 MiB 5 times to transfer a total of 100 MiB.
for range 5 {
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
b.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
}
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
b.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
}
}
}
}

func testIncorrectMsgType(t *testing.T, rp string) {
// framedMsg is an empty ciphertext with correct framing but wrong
// message type.
Expand Down
2 changes: 1 addition & 1 deletion credentials/alts/internal/handshaker/handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*al
// whatever received from the network and send it to the handshaker service.
func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) {
var lastWriteTime time.Time
buf := make([]byte, frameLimit)
for {
if len(resp.OutFrames) > 0 {
lastWriteTime = time.Now()
Expand All @@ -318,7 +319,6 @@ func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []b
if resp.Result != nil {
return resp.Result, extra, nil
}
buf := make([]byte, frameLimit)
n, err := h.conn.Read(buf)
if err != nil && err != io.EOF {
return nil, nil, err
Expand Down