Skip to content

Commit 1135b16

Browse files
committed
新v3函数+修改limit逻辑
1 parent fc13e78 commit 1135b16

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

frame/read_frame_bufio_v2.go

+32
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
package frame
1515

1616
import (
17+
"fmt"
1718
"io"
1819

1920
"github.com/antlabs/wsutil/enum"
@@ -22,15 +23,46 @@ import (
2223

2324
func ReadFrameFromReaderV2(r io.Reader, headArray *[enum.MaxFrameHeaderSize]byte, buf *[]byte) (f Frame2, err error) {
2425
h, _, err := ReadHeader(r, headArray)
26+
if err != nil {
27+
return f, fmt.Errorf("ReadFrameFromReaderV2:%w", err)
28+
}
29+
30+
if cap(*buf) < int(h.PayloadLen) {
31+
// TODO sync.Pool 处理
32+
*buf = make([]byte, h.PayloadLen)
33+
}
34+
*buf = (*buf)[:h.PayloadLen]
35+
n1, err := io.ReadFull(r, *buf)
2536
if err != nil {
2637
return f, err
2738
}
39+
if n1 != int(h.PayloadLen) {
40+
return f, io.ErrUnexpectedEOF
41+
}
42+
f.Payload = buf
43+
f.FrameHeader = h
44+
if h.Mask {
45+
mask.Mask(*f.Payload, h.MaskKey)
46+
}
47+
48+
return f, nil
49+
}
50+
51+
func ReadFrameFromReaderV3(r io.Reader, lr io.Reader, headArray *[enum.MaxFrameHeaderSize]byte, buf *[]byte) (f Frame2, err error) {
52+
h, _, err := ReadHeader(r, headArray)
53+
if err != nil {
54+
return f, fmt.Errorf("ReadFrameFromReaderV2:%w", err)
55+
}
2856

2957
if cap(*buf) < int(h.PayloadLen) {
3058
// TODO sync.Pool 处理
3159
*buf = make([]byte, h.PayloadLen)
3260
}
3361
*buf = (*buf)[:h.PayloadLen]
62+
if lr != nil {
63+
r = lr
64+
}
65+
3466
n1, err := io.ReadFull(r, *buf)
3567
if err != nil {
3668
return f, err

limitreader/limitreader.go

+16-1
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,27 @@ func NewLimitReader(r io.Reader, m int64) *limitReader {
1717
return &limitReader{r: r, m: m}
1818
}
1919

20+
// 目前go.mod使用的是go1.20版本, go1.21才有min函数
21+
func minSize(n, m int) int {
22+
if n < m {
23+
return n
24+
}
25+
return m
26+
}
27+
2028
// 实现io.Reader接口
2129
func (l *limitReader) Read(p []byte) (n int, err error) {
2230
if l.m < 0 {
2331
return 0, ErrTooBigMessage
2432
}
25-
n, err = l.r.Read(p)
33+
rn := minSize(int(l.m), len(p))
34+
if rn == 0 && len(p) > 0 {
35+
rn = 1
36+
}
37+
n, err = l.r.Read(p[:rn])
2638
l.m -= int64(n)
39+
if l.m < 0 {
40+
return 0, ErrTooBigMessage
41+
}
2742
return
2843
}

0 commit comments

Comments
 (0)