Skip to content

Commit

Permalink
Fix: Data race in topic message query (#787)
Browse files Browse the repository at this point in the history
Signed-off-by: Emanuel Pargov <[email protected]>
  • Loading branch information
bamzedev authored Aug 14, 2023
1 parent 7055eca commit 0ba653f
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ jobs:
HEDERA_NETWORK: testnet
OPERATOR_ID: ${{ secrets.TESTNET_OPERATOR_ID }}
OPERATOR_KEY: ${{ secrets.TESTNET_OPERATOR_KEY }}
run: go test -tags="testnets" -timeout 9999s -v -coverprofile=testnets.out -covermode=atomic
run: go test -tags="testnets" -timeout 9999s -v -coverprofile=testnets.out -covermode=atomic -race

- name: Upload coverage to Codecov
if: success()
Expand Down
12 changes: 10 additions & 2 deletions topic_message_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"io"
"math"
"regexp"
"sync"
"time"

"github.com/hashgraph/hedera-protobufs-go/services"
Expand All @@ -48,6 +49,7 @@ type TopicMessageQuery struct {
startTime *time.Time
endTime *time.Time
limit uint64
mu sync.Mutex
}

// NewTopicMessageQuery creates TopicMessageQuery which
Expand Down Expand Up @@ -185,6 +187,8 @@ func (query *TopicMessageQuery) _Build() *mirror.ConsensusTopicQuery {

// Subscribe subscribes to messages sent to the specific TopicID
func (query *TopicMessageQuery) Subscribe(client *Client, onNext func(TopicMessage)) (SubscriptionHandle, error) {
var once sync.Once
done := make(chan struct{})
handle := SubscriptionHandle{}

err := query._ValidateNetworkOnIDs(client)
Expand All @@ -202,6 +206,8 @@ func (query *TopicMessageQuery) Subscribe(client *Client, onNext func(TopicMessa
}

go func() {
query.mu.Lock()
defer query.mu.Unlock()
var subClient mirror.ConsensusService_SubscribeTopicClient
var err error

Expand Down Expand Up @@ -231,7 +237,9 @@ func (query *TopicMessageQuery) Subscribe(client *Client, onNext func(TopicMessa
if subClient == nil {
ctx, cancel := context.WithCancel(context.TODO())
handle.onUnsubscribe = cancel

once.Do(func() {
close(done)
})
subClient, err = (*channel).SubscribeTopic(ctx, pb)

if err != nil {
Expand Down Expand Up @@ -274,7 +282,7 @@ func (query *TopicMessageQuery) Subscribe(client *Client, onNext func(TopicMessa
}
}
}()

<-done
return handle, nil
}

Expand Down
55 changes: 30 additions & 25 deletions topic_message_query_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ package hedera

import (
"errors"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -85,6 +86,7 @@ Etiam ut sodales ex. Nulla luctus, magna eu scelerisque sagittis, nibh quam cons
func TestIntegrationTopicMessageQueryCanExecute(t *testing.T) {
t.Parallel()
env := NewIntegrationTestEnv(t)
var finished int32 // 0 for false, 1 for true

resp, err := NewTopicCreateTransaction().
SetAdminKey(env.Client.GetOperatorPublicKey()).
Expand All @@ -99,15 +101,14 @@ func TestIntegrationTopicMessageQueryCanExecute(t *testing.T) {
topicID := *receipt.TopicID
assert.NotNil(t, topicID)

finished := false
start := time.Now()

_, err = NewTopicMessageQuery().
SetTopicID(topicID).
SetStartTime(time.Unix(0, 0)).
SetLimit(14).
SetLimit(1).
SetCompletionHandler(func() {
finished = true
atomic.StoreInt32(&finished, 1)
}).
Subscribe(env.Client, func(message TopicMessage) {
// Do nothing
Expand All @@ -125,7 +126,8 @@ func TestIntegrationTopicMessageQueryCanExecute(t *testing.T) {
require.NoError(t, err)

for {
if finished || uint64(time.Since(start).Seconds()) > 60 {
condition := atomic.LoadInt32(&finished) == 1 || uint64(time.Since(start).Seconds()) > 60
if condition {
break
}

Expand All @@ -141,7 +143,7 @@ func TestIntegrationTopicMessageQueryCanExecute(t *testing.T) {
_, err = resp.SetValidateStatus(true).GetReceipt(env.Client)
require.NoError(t, err)

if !finished {
if atomic.LoadInt32(&finished) != 1 {
err = errors.New("Message was not received within 60 seconds")
}
require.NoError(t, err)
Expand All @@ -153,7 +155,7 @@ func TestIntegrationTopicMessageQueryCanExecute(t *testing.T) {
func TestIntegrationTopicMessageQueryNoTopicID(t *testing.T) {
t.Parallel()
env := NewIntegrationTestEnv(t)

var wait int32 = 1 // 1 for true, 0 for false
resp, err := NewTopicCreateTransaction().
SetAdminKey(env.Client.GetOperatorPublicKey()).
SetNodeAccountIDs(env.NodeAccountIDs).
Expand All @@ -167,15 +169,14 @@ func TestIntegrationTopicMessageQueryNoTopicID(t *testing.T) {
topicID := *receipt.TopicID
assert.NotNil(t, topicID)

wait := true
start := time.Now()

_, err = NewTopicMessageQuery().
SetTopicID(topicID).
SetStartTime(time.Unix(0, 0)).
Subscribe(env.Client, func(message TopicMessage) {
if string(message.Contents) == bigContents {
wait = false
atomic.StoreInt32(&wait, 0)
}
})
require.NoError(t, err)
Expand All @@ -192,7 +193,8 @@ func TestIntegrationTopicMessageQueryNoTopicID(t *testing.T) {
}

for {
if err != nil || !wait || uint64(time.Since(start).Seconds()) > 30 {
condition := atomic.LoadInt32(&wait) == 0 || err != nil || uint64(time.Since(start).Seconds()) > 30
if condition {
break
}

Expand All @@ -208,7 +210,7 @@ func TestIntegrationTopicMessageQueryNoTopicID(t *testing.T) {
_, err = resp.SetValidateStatus(true).GetReceipt(env.Client)
require.NoError(t, err)

if wait {
if atomic.LoadInt32(&wait) == 1 {
err = errors.New("Message was not received within 30 seconds")
}
assert.Error(t, err)
Expand All @@ -220,7 +222,7 @@ func TestIntegrationTopicMessageQueryNoTopicID(t *testing.T) {
func TestIntegrationTopicMessageQueryNoMessage(t *testing.T) {
t.Parallel()
env := NewIntegrationTestEnv(t)

var wait int32 = 1 // 1 for true, 0 for false
resp, err := NewTopicCreateTransaction().
SetAdminKey(env.Client.GetOperatorPublicKey()).
SetNodeAccountIDs(env.NodeAccountIDs).
Expand All @@ -234,15 +236,14 @@ func TestIntegrationTopicMessageQueryNoMessage(t *testing.T) {
topicID := *receipt.TopicID
assert.NotNil(t, topicID)

wait := true
start := time.Now()

_, err = NewTopicMessageQuery().
SetTopicID(topicID).
SetStartTime(time.Unix(0, 0)).
Subscribe(env.Client, func(message TopicMessage) {
if string(message.Contents) == bigContents {
wait = false
atomic.StoreInt32(&wait, 0)
}
})
require.NoError(t, err)
Expand All @@ -257,7 +258,8 @@ func TestIntegrationTopicMessageQueryNoMessage(t *testing.T) {
}

for {
if err != nil || !wait || uint64(time.Since(start).Seconds()) > 30 {
condition := atomic.LoadInt32(&wait) == 0 || err != nil || uint64(time.Since(start).Seconds()) > 30
if condition {
break
}

Expand All @@ -273,9 +275,10 @@ func TestIntegrationTopicMessageQueryNoMessage(t *testing.T) {
_, err = resp.SetValidateStatus(true).GetReceipt(env.Client)
require.NoError(t, err)

if wait {
if atomic.LoadInt32(&wait) == 1 {
err = errors.New("Message was not received within 30 seconds")
}

assert.Error(t, err)

err = CloseIntegrationTestEnv(env, nil)
Expand All @@ -285,6 +288,7 @@ func TestIntegrationTopicMessageQueryNoMessage(t *testing.T) {
func TestIntegrationTopicMessageQueryNoStartTime(t *testing.T) {
t.Parallel()
env := NewIntegrationTestEnv(t)
var finished int32 = 0 // 0 for false, 1 for true

resp, err := NewTopicCreateTransaction().
SetAdminKey(env.Client.GetOperatorPublicKey()).
Expand All @@ -299,7 +303,6 @@ func TestIntegrationTopicMessageQueryNoStartTime(t *testing.T) {
topicID := *receipt.TopicID
assert.NotNil(t, topicID)

finished := false
start := time.Now()

_, err = NewTopicMessageQuery().
Expand All @@ -308,9 +311,10 @@ func TestIntegrationTopicMessageQueryNoStartTime(t *testing.T) {
SetEndTime(time.Now().Add(time.Second*20)).
Subscribe(env.Client, func(message TopicMessage) {
if string(message.Contents) == bigContents {
finished = true
atomic.StoreInt32(&finished, 1)
}
})

require.NoError(t, err)

resp, err = NewTopicMessageSubmitTransaction().
Expand All @@ -324,7 +328,8 @@ func TestIntegrationTopicMessageQueryNoStartTime(t *testing.T) {
require.NoError(t, err)

for {
if err != nil || finished || uint64(time.Since(start).Seconds()) > 60 {
condition := atomic.LoadInt32(&finished) == 1 || uint64(time.Since(start).Seconds()) > 60
if condition {
break
}

Expand All @@ -340,9 +345,10 @@ func TestIntegrationTopicMessageQueryNoStartTime(t *testing.T) {
_, err = resp.SetValidateStatus(true).GetReceipt(env.Client)
require.NoError(t, err)

if !finished {
if atomic.LoadInt32(&finished) == 0 {
err = errors.New("Message was not received within 60 seconds")
}

assert.NoError(t, err)

err = CloseIntegrationTestEnv(env, nil)
Expand All @@ -353,8 +359,8 @@ func TestIntegrationTopicMessageQueryCanExecuteWithTls(t *testing.T) {
client := ClientForNetwork(map[string]AccountID{})
client.SetMirrorNetwork([]string{"mainnet-public.mirrornode.hedera.com:443"})
client.SetTransportSecurity(true)
var finished int32 = 0 // 0 for false, 1 for true

finished := false
start := time.Now()
end := start.Add(5 * time.Second)

Expand All @@ -363,21 +369,20 @@ func TestIntegrationTopicMessageQueryCanExecuteWithTls(t *testing.T) {
SetStartTime(time.Unix(0, 0)).
SetLimit(10).
SetCompletionHandler(func() {
finished = true
atomic.StoreInt32(&finished, 1)
}).
Subscribe(client, func(message TopicMessage) {
})
require.NoError(t, err)

for {
if finished || time.Now().After(end) {
condition := atomic.LoadInt32(&finished) == 1 || time.Now().After(end)
if condition {
break
}

time.Sleep(2 * time.Second)
}

require.True(t, finished)
require.True(t, atomic.LoadInt32(&finished) == 1)

handle.Unsubscribe()
}

0 comments on commit 0ba653f

Please sign in to comment.