From a580c15f6a6b96973fbb5a47b2083eac276e4e82 Mon Sep 17 00:00:00 2001 From: "fengyun.rui" Date: Fri, 10 Nov 2023 22:02:02 +0800 Subject: [PATCH] perf: improve gzip performance with sync.pool (#1321) Signed-off-by: rfyiamcool Co-authored-by: Gordon <46924906+FGadvancer@users.noreply.github.com> --- internal/msggateway/client.go | 4 +- internal/msggateway/compressor.go | 45 +++++++++++ internal/msggateway/compressor_test.go | 107 +++++++++++++++++++++++++ internal/msggateway/init.go | 2 +- 4 files changed, 155 insertions(+), 3 deletions(-) create mode 100644 internal/msggateway/compressor_test.go diff --git a/internal/msggateway/client.go b/internal/msggateway/client.go index 9eeac38356..69b49d81ac 100644 --- a/internal/msggateway/client.go +++ b/internal/msggateway/client.go @@ -167,7 +167,7 @@ func (c *Client) readMessage() { func (c *Client) handleMessage(message []byte) error { if c.IsCompress { var err error - message, err = c.longConnServer.DeCompress(message) + message, err = c.longConnServer.DecompressWithPool(message) if err != nil { return utils.Wrap(err, "") } @@ -317,7 +317,7 @@ func (c *Client) writeBinaryMsg(resp Resp) error { _ = c.conn.SetWriteDeadline(writeWait) if c.IsCompress { - resultBuf, compressErr := c.longConnServer.Compress(encodedBuf) + resultBuf, compressErr := c.longConnServer.CompressWithPool(encodedBuf) if compressErr != nil { return utils.Wrap(compressErr, "") } diff --git a/internal/msggateway/compressor.go b/internal/msggateway/compressor.go index 0639fb4f01..ae5e9cdd04 100644 --- a/internal/msggateway/compressor.go +++ b/internal/msggateway/compressor.go @@ -17,14 +17,23 @@ package msggateway import ( "bytes" "compress/gzip" + "errors" "io" + "sync" "github.com/OpenIMSDK/tools/utils" ) +var ( + gzipWriterPool = sync.Pool{New: func() any { return gzip.NewWriter(nil) }} + gzipReaderPool = sync.Pool{New: func() any { return new(gzip.Reader) }} +) + type Compressor interface { Compress(rawData []byte) ([]byte, error) + CompressWithPool(rawData []byte) ([]byte, error) DeCompress(compressedData []byte) ([]byte, error) + DecompressWithPool(compressedData []byte) ([]byte, error) } type GzipCompressor struct { compressProtocol string @@ -46,6 +55,22 @@ func (g *GzipCompressor) Compress(rawData []byte) ([]byte, error) { return gzipBuffer.Bytes(), nil } +func (g *GzipCompressor) CompressWithPool(rawData []byte) ([]byte, error) { + gz := gzipWriterPool.Get().(*gzip.Writer) + defer gzipWriterPool.Put(gz) + + gzipBuffer := bytes.Buffer{} + gz.Reset(&gzipBuffer) + + if _, err := gz.Write(rawData); err != nil { + return nil, utils.Wrap(err, "") + } + if err := gz.Close(); err != nil { + return nil, utils.Wrap(err, "") + } + return gzipBuffer.Bytes(), nil +} + func (g *GzipCompressor) DeCompress(compressedData []byte) ([]byte, error) { buff := bytes.NewBuffer(compressedData) reader, err := gzip.NewReader(buff) @@ -59,3 +84,23 @@ func (g *GzipCompressor) DeCompress(compressedData []byte) ([]byte, error) { _ = reader.Close() return compressedData, nil } + +func (g *GzipCompressor) DecompressWithPool(compressedData []byte) ([]byte, error) { + reader := gzipReaderPool.Get().(*gzip.Reader) + if reader == nil { + return nil, errors.New("NewReader failed") + } + defer gzipReaderPool.Put(reader) + + err := reader.Reset(bytes.NewReader(compressedData)) + if err != nil { + return nil, utils.Wrap(err, "NewReader failed") + } + + compressedData, err = io.ReadAll(reader) + if err != nil { + return nil, utils.Wrap(err, "ReadAll failed") + } + _ = reader.Close() + return compressedData, nil +} diff --git a/internal/msggateway/compressor_test.go b/internal/msggateway/compressor_test.go new file mode 100644 index 0000000000..d41c57bf3f --- /dev/null +++ b/internal/msggateway/compressor_test.go @@ -0,0 +1,107 @@ +package msggateway + +import ( + "crypto/rand" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func mockRandom() []byte { + bs := make([]byte, 50) + rand.Read(bs) + return bs +} + +func TestCompressDecompress(t *testing.T) { + + compressor := NewGzipCompressor() + + for i := 0; i < 2000; i++ { + src := mockRandom() + + // compress + dest, err := compressor.CompressWithPool(src) + assert.Equal(t, nil, err) + + // decompress + res, err := compressor.DecompressWithPool(dest) + assert.Equal(t, nil, err) + + // check + assert.EqualValues(t, src, res) + } +} + +func TestCompressDecompressWithConcurrency(t *testing.T) { + wg := sync.WaitGroup{} + compressor := NewGzipCompressor() + + for i := 0; i < 200; i++ { + wg.Add(1) + go func() { + defer wg.Done() + src := mockRandom() + + // compress + dest, err := compressor.CompressWithPool(src) + assert.Equal(t, nil, err) + + // decompress + res, err := compressor.DecompressWithPool(dest) + assert.Equal(t, nil, err) + + // check + assert.EqualValues(t, src, res) + + }() + } + wg.Wait() +} + +func BenchmarkCompress(b *testing.B) { + src := mockRandom() + compressor := NewGzipCompressor() + + for i := 0; i < b.N; i++ { + _, err := compressor.Compress(src) + assert.Equal(b, nil, err) + } +} + +func BenchmarkCompressWithSyncPool(b *testing.B) { + src := mockRandom() + + compressor := NewGzipCompressor() + for i := 0; i < b.N; i++ { + _, err := compressor.CompressWithPool(src) + assert.Equal(b, nil, err) + } +} + +func BenchmarkDecompress(b *testing.B) { + src := mockRandom() + + compressor := NewGzipCompressor() + comdata, err := compressor.Compress(src) + assert.Equal(b, nil, err) + + for i := 0; i < b.N; i++ { + _, err := compressor.DeCompress(comdata) + assert.Equal(b, nil, err) + } +} + +func BenchmarkDecompressWithSyncPool(b *testing.B) { + src := mockRandom() + + compressor := NewGzipCompressor() + comdata, err := compressor.Compress(src) + assert.Equal(b, nil, err) + + for i := 0; i < b.N; i++ { + _, err := compressor.DecompressWithPool(comdata) + assert.Equal(b, nil, err) + } +} diff --git a/internal/msggateway/init.go b/internal/msggateway/init.go index 94f1b20118..12a6d37703 100644 --- a/internal/msggateway/init.go +++ b/internal/msggateway/init.go @@ -23,7 +23,7 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/config" ) -// RunWsAndServer run ws server +// RunWsAndServer run ws server. func RunWsAndServer(rpcPort, wsPort, prometheusPort int) error { fmt.Println( "start rpc/msg_gateway server, port: ",