From a2a2bcdc518087a8deedc6922ec9613d5ba21da2 Mon Sep 17 00:00:00 2001 From: Emanuel Pargov Date: Mon, 14 Aug 2023 16:07:29 +0300 Subject: [PATCH] Fix: Data race in topic message query Signed-off-by: Emanuel Pargov --- .github/workflows/build.yml | 2 +- topic_message_query.go | 12 +++++-- topic_message_query_e2e_test.go | 55 ++++++++++++++++++--------------- 3 files changed, 41 insertions(+), 28 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b0ff74dd4..ef278c089 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -74,7 +74,7 @@ jobs: HEDERA_NETWORK: testnet OPERATOR_ID: ${{ secrets.TESTNET_OPERATOR_ID }} OPERATOR_KEY: ${{ secrets.TESTNET_OPERATOR_KEY }} - run: go test -tags="testnets" -timeout 9999s -v -coverprofile=testnets.out -covermode=atomic + run: go test -tags="testnets" -timeout 9999s -v -coverprofile=testnets.out -covermode=atomic -race - name: Upload coverage to Codecov if: success() diff --git a/topic_message_query.go b/topic_message_query.go index 550209657..cf1bb419d 100644 --- a/topic_message_query.go +++ b/topic_message_query.go @@ -25,6 +25,7 @@ import ( "io" "math" "regexp" + "sync" "time" "github.com/hashgraph/hedera-protobufs-go/services" @@ -48,6 +49,7 @@ type TopicMessageQuery struct { startTime *time.Time endTime *time.Time limit uint64 + mu sync.Mutex } // NewTopicMessageQuery creates TopicMessageQuery which @@ -185,6 +187,8 @@ func (query *TopicMessageQuery) _Build() *mirror.ConsensusTopicQuery { // Subscribe subscribes to messages sent to the specific TopicID func (query *TopicMessageQuery) Subscribe(client *Client, onNext func(TopicMessage)) (SubscriptionHandle, error) { + var once sync.Once + done := make(chan struct{}) handle := SubscriptionHandle{} err := query._ValidateNetworkOnIDs(client) @@ -202,6 +206,8 @@ func (query *TopicMessageQuery) Subscribe(client *Client, onNext func(TopicMessa } go func() { + query.mu.Lock() + defer query.mu.Unlock() var subClient mirror.ConsensusService_SubscribeTopicClient var err error @@ -231,7 +237,9 @@ func (query *TopicMessageQuery) Subscribe(client *Client, onNext func(TopicMessa if subClient == nil { ctx, cancel := context.WithCancel(context.TODO()) handle.onUnsubscribe = cancel - + once.Do(func() { + close(done) + }) subClient, err = (*channel).SubscribeTopic(ctx, pb) if err != nil { @@ -274,7 +282,7 @@ func (query *TopicMessageQuery) Subscribe(client *Client, onNext func(TopicMessa } } }() - + <-done return handle, nil } diff --git a/topic_message_query_e2e_test.go b/topic_message_query_e2e_test.go index 573733216..69d4fcf88 100644 --- a/topic_message_query_e2e_test.go +++ b/topic_message_query_e2e_test.go @@ -25,6 +25,7 @@ package hedera import ( "errors" + "sync/atomic" "testing" "time" @@ -85,6 +86,7 @@ Etiam ut sodales ex. Nulla luctus, magna eu scelerisque sagittis, nibh quam cons func TestIntegrationTopicMessageQueryCanExecute(t *testing.T) { t.Parallel() env := NewIntegrationTestEnv(t) + var finished int32 // 0 for false, 1 for true resp, err := NewTopicCreateTransaction(). SetAdminKey(env.Client.GetOperatorPublicKey()). @@ -99,15 +101,14 @@ func TestIntegrationTopicMessageQueryCanExecute(t *testing.T) { topicID := *receipt.TopicID assert.NotNil(t, topicID) - finished := false start := time.Now() _, err = NewTopicMessageQuery(). SetTopicID(topicID). SetStartTime(time.Unix(0, 0)). - SetLimit(14). + SetLimit(1). SetCompletionHandler(func() { - finished = true + atomic.StoreInt32(&finished, 1) }). Subscribe(env.Client, func(message TopicMessage) { // Do nothing @@ -125,7 +126,8 @@ func TestIntegrationTopicMessageQueryCanExecute(t *testing.T) { require.NoError(t, err) for { - if finished || uint64(time.Since(start).Seconds()) > 60 { + condition := atomic.LoadInt32(&finished) == 1 || uint64(time.Since(start).Seconds()) > 60 + if condition { break } @@ -141,7 +143,7 @@ func TestIntegrationTopicMessageQueryCanExecute(t *testing.T) { _, err = resp.SetValidateStatus(true).GetReceipt(env.Client) require.NoError(t, err) - if !finished { + if atomic.LoadInt32(&finished) != 1 { err = errors.New("Message was not received within 60 seconds") } require.NoError(t, err) @@ -153,7 +155,7 @@ func TestIntegrationTopicMessageQueryCanExecute(t *testing.T) { func TestIntegrationTopicMessageQueryNoTopicID(t *testing.T) { t.Parallel() env := NewIntegrationTestEnv(t) - + var wait int32 = 1 // 1 for true, 0 for false resp, err := NewTopicCreateTransaction(). SetAdminKey(env.Client.GetOperatorPublicKey()). SetNodeAccountIDs(env.NodeAccountIDs). @@ -167,7 +169,6 @@ func TestIntegrationTopicMessageQueryNoTopicID(t *testing.T) { topicID := *receipt.TopicID assert.NotNil(t, topicID) - wait := true start := time.Now() _, err = NewTopicMessageQuery(). @@ -175,7 +176,7 @@ func TestIntegrationTopicMessageQueryNoTopicID(t *testing.T) { SetStartTime(time.Unix(0, 0)). Subscribe(env.Client, func(message TopicMessage) { if string(message.Contents) == bigContents { - wait = false + atomic.StoreInt32(&wait, 0) } }) require.NoError(t, err) @@ -192,7 +193,8 @@ func TestIntegrationTopicMessageQueryNoTopicID(t *testing.T) { } for { - if err != nil || !wait || uint64(time.Since(start).Seconds()) > 30 { + condition := atomic.LoadInt32(&wait) == 0 || err != nil || uint64(time.Since(start).Seconds()) > 30 + if condition { break } @@ -208,7 +210,7 @@ func TestIntegrationTopicMessageQueryNoTopicID(t *testing.T) { _, err = resp.SetValidateStatus(true).GetReceipt(env.Client) require.NoError(t, err) - if wait { + if atomic.LoadInt32(&wait) == 1 { err = errors.New("Message was not received within 30 seconds") } assert.Error(t, err) @@ -220,7 +222,7 @@ func TestIntegrationTopicMessageQueryNoTopicID(t *testing.T) { func TestIntegrationTopicMessageQueryNoMessage(t *testing.T) { t.Parallel() env := NewIntegrationTestEnv(t) - + var wait int32 = 1 // 1 for true, 0 for false resp, err := NewTopicCreateTransaction(). SetAdminKey(env.Client.GetOperatorPublicKey()). SetNodeAccountIDs(env.NodeAccountIDs). @@ -234,7 +236,6 @@ func TestIntegrationTopicMessageQueryNoMessage(t *testing.T) { topicID := *receipt.TopicID assert.NotNil(t, topicID) - wait := true start := time.Now() _, err = NewTopicMessageQuery(). @@ -242,7 +243,7 @@ func TestIntegrationTopicMessageQueryNoMessage(t *testing.T) { SetStartTime(time.Unix(0, 0)). Subscribe(env.Client, func(message TopicMessage) { if string(message.Contents) == bigContents { - wait = false + atomic.StoreInt32(&wait, 0) } }) require.NoError(t, err) @@ -257,7 +258,8 @@ func TestIntegrationTopicMessageQueryNoMessage(t *testing.T) { } for { - if err != nil || !wait || uint64(time.Since(start).Seconds()) > 30 { + condition := atomic.LoadInt32(&wait) == 0 || err != nil || uint64(time.Since(start).Seconds()) > 30 + if condition { break } @@ -273,9 +275,10 @@ func TestIntegrationTopicMessageQueryNoMessage(t *testing.T) { _, err = resp.SetValidateStatus(true).GetReceipt(env.Client) require.NoError(t, err) - if wait { + if atomic.LoadInt32(&wait) == 1 { err = errors.New("Message was not received within 30 seconds") } + assert.Error(t, err) err = CloseIntegrationTestEnv(env, nil) @@ -285,6 +288,7 @@ func TestIntegrationTopicMessageQueryNoMessage(t *testing.T) { func TestIntegrationTopicMessageQueryNoStartTime(t *testing.T) { t.Parallel() env := NewIntegrationTestEnv(t) + var finished int32 = 0 // 0 for false, 1 for true resp, err := NewTopicCreateTransaction(). SetAdminKey(env.Client.GetOperatorPublicKey()). @@ -299,7 +303,6 @@ func TestIntegrationTopicMessageQueryNoStartTime(t *testing.T) { topicID := *receipt.TopicID assert.NotNil(t, topicID) - finished := false start := time.Now() _, err = NewTopicMessageQuery(). @@ -308,9 +311,10 @@ func TestIntegrationTopicMessageQueryNoStartTime(t *testing.T) { SetEndTime(time.Now().Add(time.Second*20)). Subscribe(env.Client, func(message TopicMessage) { if string(message.Contents) == bigContents { - finished = true + atomic.StoreInt32(&finished, 1) } }) + require.NoError(t, err) resp, err = NewTopicMessageSubmitTransaction(). @@ -324,7 +328,8 @@ func TestIntegrationTopicMessageQueryNoStartTime(t *testing.T) { require.NoError(t, err) for { - if err != nil || finished || uint64(time.Since(start).Seconds()) > 60 { + condition := atomic.LoadInt32(&finished) == 1 || uint64(time.Since(start).Seconds()) > 60 + if condition { break } @@ -340,9 +345,10 @@ func TestIntegrationTopicMessageQueryNoStartTime(t *testing.T) { _, err = resp.SetValidateStatus(true).GetReceipt(env.Client) require.NoError(t, err) - if !finished { + if atomic.LoadInt32(&finished) == 0 { err = errors.New("Message was not received within 60 seconds") } + assert.NoError(t, err) err = CloseIntegrationTestEnv(env, nil) @@ -353,8 +359,8 @@ func TestIntegrationTopicMessageQueryCanExecuteWithTls(t *testing.T) { client := ClientForNetwork(map[string]AccountID{}) client.SetMirrorNetwork([]string{"mainnet-public.mirrornode.hedera.com:443"}) client.SetTransportSecurity(true) + var finished int32 = 0 // 0 for false, 1 for true - finished := false start := time.Now() end := start.Add(5 * time.Second) @@ -363,21 +369,20 @@ func TestIntegrationTopicMessageQueryCanExecuteWithTls(t *testing.T) { SetStartTime(time.Unix(0, 0)). SetLimit(10). SetCompletionHandler(func() { - finished = true + atomic.StoreInt32(&finished, 1) }). Subscribe(client, func(message TopicMessage) { }) require.NoError(t, err) - for { - if finished || time.Now().After(end) { + condition := atomic.LoadInt32(&finished) == 1 || time.Now().After(end) + if condition { break } - time.Sleep(2 * time.Second) } - require.True(t, finished) + require.True(t, atomic.LoadInt32(&finished) == 1) handle.Unsubscribe() }