Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Data race in topic message query #787

Merged
merged 1 commit into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ jobs:
HEDERA_NETWORK: testnet
OPERATOR_ID: ${{ secrets.TESTNET_OPERATOR_ID }}
OPERATOR_KEY: ${{ secrets.TESTNET_OPERATOR_KEY }}
run: go test -tags="testnets" -timeout 9999s -v -coverprofile=testnets.out -covermode=atomic
run: go test -tags="testnets" -timeout 9999s -v -coverprofile=testnets.out -covermode=atomic -race

- name: Upload coverage to Codecov
if: success()
Expand Down
12 changes: 10 additions & 2 deletions topic_message_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"io"
"math"
"regexp"
"sync"
"time"

"github.com/hashgraph/hedera-protobufs-go/services"
Expand All @@ -48,6 +49,7 @@ type TopicMessageQuery struct {
startTime *time.Time
endTime *time.Time
limit uint64
mu sync.Mutex
}

// NewTopicMessageQuery creates TopicMessageQuery which
Expand Down Expand Up @@ -185,6 +187,8 @@ func (query *TopicMessageQuery) _Build() *mirror.ConsensusTopicQuery {

// Subscribe subscribes to messages sent to the specific TopicID
func (query *TopicMessageQuery) Subscribe(client *Client, onNext func(TopicMessage)) (SubscriptionHandle, error) {
var once sync.Once
done := make(chan struct{})
handle := SubscriptionHandle{}

err := query._ValidateNetworkOnIDs(client)
Expand All @@ -202,6 +206,8 @@ func (query *TopicMessageQuery) Subscribe(client *Client, onNext func(TopicMessa
}

go func() {
query.mu.Lock()
defer query.mu.Unlock()
var subClient mirror.ConsensusService_SubscribeTopicClient
var err error

Expand Down Expand Up @@ -231,7 +237,9 @@ func (query *TopicMessageQuery) Subscribe(client *Client, onNext func(TopicMessa
if subClient == nil {
ctx, cancel := context.WithCancel(context.TODO())
handle.onUnsubscribe = cancel

once.Do(func() {
close(done)
})
subClient, err = (*channel).SubscribeTopic(ctx, pb)

if err != nil {
Expand Down Expand Up @@ -274,7 +282,7 @@ func (query *TopicMessageQuery) Subscribe(client *Client, onNext func(TopicMessa
}
}
}()

<-done
return handle, nil
}

Expand Down
55 changes: 30 additions & 25 deletions topic_message_query_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ package hedera

import (
"errors"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -85,6 +86,7 @@ Etiam ut sodales ex. Nulla luctus, magna eu scelerisque sagittis, nibh quam cons
func TestIntegrationTopicMessageQueryCanExecute(t *testing.T) {
t.Parallel()
env := NewIntegrationTestEnv(t)
var finished int32 // 0 for false, 1 for true

resp, err := NewTopicCreateTransaction().
SetAdminKey(env.Client.GetOperatorPublicKey()).
Expand All @@ -99,15 +101,14 @@ func TestIntegrationTopicMessageQueryCanExecute(t *testing.T) {
topicID := *receipt.TopicID
assert.NotNil(t, topicID)

finished := false
start := time.Now()

_, err = NewTopicMessageQuery().
SetTopicID(topicID).
SetStartTime(time.Unix(0, 0)).
SetLimit(14).
SetLimit(1).
SetCompletionHandler(func() {
finished = true
atomic.StoreInt32(&finished, 1)
}).
Subscribe(env.Client, func(message TopicMessage) {
// Do nothing
Expand All @@ -125,7 +126,8 @@ func TestIntegrationTopicMessageQueryCanExecute(t *testing.T) {
require.NoError(t, err)

for {
if finished || uint64(time.Since(start).Seconds()) > 60 {
condition := atomic.LoadInt32(&finished) == 1 || uint64(time.Since(start).Seconds()) > 60
if condition {
break
}

Expand All @@ -141,7 +143,7 @@ func TestIntegrationTopicMessageQueryCanExecute(t *testing.T) {
_, err = resp.SetValidateStatus(true).GetReceipt(env.Client)
require.NoError(t, err)

if !finished {
if atomic.LoadInt32(&finished) != 1 {
err = errors.New("Message was not received within 60 seconds")
}
require.NoError(t, err)
Expand All @@ -153,7 +155,7 @@ func TestIntegrationTopicMessageQueryCanExecute(t *testing.T) {
func TestIntegrationTopicMessageQueryNoTopicID(t *testing.T) {
t.Parallel()
env := NewIntegrationTestEnv(t)

var wait int32 = 1 // 1 for true, 0 for false
resp, err := NewTopicCreateTransaction().
SetAdminKey(env.Client.GetOperatorPublicKey()).
SetNodeAccountIDs(env.NodeAccountIDs).
Expand All @@ -167,15 +169,14 @@ func TestIntegrationTopicMessageQueryNoTopicID(t *testing.T) {
topicID := *receipt.TopicID
assert.NotNil(t, topicID)

wait := true
start := time.Now()

_, err = NewTopicMessageQuery().
SetTopicID(topicID).
SetStartTime(time.Unix(0, 0)).
Subscribe(env.Client, func(message TopicMessage) {
if string(message.Contents) == bigContents {
wait = false
atomic.StoreInt32(&wait, 0)
}
})
require.NoError(t, err)
Expand All @@ -192,7 +193,8 @@ func TestIntegrationTopicMessageQueryNoTopicID(t *testing.T) {
}

for {
if err != nil || !wait || uint64(time.Since(start).Seconds()) > 30 {
condition := atomic.LoadInt32(&wait) == 0 || err != nil || uint64(time.Since(start).Seconds()) > 30
if condition {
break
}

Expand All @@ -208,7 +210,7 @@ func TestIntegrationTopicMessageQueryNoTopicID(t *testing.T) {
_, err = resp.SetValidateStatus(true).GetReceipt(env.Client)
require.NoError(t, err)

if wait {
if atomic.LoadInt32(&wait) == 1 {
err = errors.New("Message was not received within 30 seconds")
}
assert.Error(t, err)
Expand All @@ -220,7 +222,7 @@ func TestIntegrationTopicMessageQueryNoTopicID(t *testing.T) {
func TestIntegrationTopicMessageQueryNoMessage(t *testing.T) {
t.Parallel()
env := NewIntegrationTestEnv(t)

var wait int32 = 1 // 1 for true, 0 for false
resp, err := NewTopicCreateTransaction().
SetAdminKey(env.Client.GetOperatorPublicKey()).
SetNodeAccountIDs(env.NodeAccountIDs).
Expand All @@ -234,15 +236,14 @@ func TestIntegrationTopicMessageQueryNoMessage(t *testing.T) {
topicID := *receipt.TopicID
assert.NotNil(t, topicID)

wait := true
start := time.Now()

_, err = NewTopicMessageQuery().
SetTopicID(topicID).
SetStartTime(time.Unix(0, 0)).
Subscribe(env.Client, func(message TopicMessage) {
if string(message.Contents) == bigContents {
wait = false
atomic.StoreInt32(&wait, 0)
}
})
require.NoError(t, err)
Expand All @@ -257,7 +258,8 @@ func TestIntegrationTopicMessageQueryNoMessage(t *testing.T) {
}

for {
if err != nil || !wait || uint64(time.Since(start).Seconds()) > 30 {
condition := atomic.LoadInt32(&wait) == 0 || err != nil || uint64(time.Since(start).Seconds()) > 30
if condition {
break
}

Expand All @@ -273,9 +275,10 @@ func TestIntegrationTopicMessageQueryNoMessage(t *testing.T) {
_, err = resp.SetValidateStatus(true).GetReceipt(env.Client)
require.NoError(t, err)

if wait {
if atomic.LoadInt32(&wait) == 1 {
err = errors.New("Message was not received within 30 seconds")
}

assert.Error(t, err)

err = CloseIntegrationTestEnv(env, nil)
Expand All @@ -285,6 +288,7 @@ func TestIntegrationTopicMessageQueryNoMessage(t *testing.T) {
func TestIntegrationTopicMessageQueryNoStartTime(t *testing.T) {
t.Parallel()
env := NewIntegrationTestEnv(t)
var finished int32 = 0 // 0 for false, 1 for true

resp, err := NewTopicCreateTransaction().
SetAdminKey(env.Client.GetOperatorPublicKey()).
Expand All @@ -299,7 +303,6 @@ func TestIntegrationTopicMessageQueryNoStartTime(t *testing.T) {
topicID := *receipt.TopicID
assert.NotNil(t, topicID)

finished := false
start := time.Now()

_, err = NewTopicMessageQuery().
Expand All @@ -308,9 +311,10 @@ func TestIntegrationTopicMessageQueryNoStartTime(t *testing.T) {
SetEndTime(time.Now().Add(time.Second*20)).
Subscribe(env.Client, func(message TopicMessage) {
if string(message.Contents) == bigContents {
finished = true
atomic.StoreInt32(&finished, 1)
}
})

require.NoError(t, err)

resp, err = NewTopicMessageSubmitTransaction().
Expand All @@ -324,7 +328,8 @@ func TestIntegrationTopicMessageQueryNoStartTime(t *testing.T) {
require.NoError(t, err)

for {
if err != nil || finished || uint64(time.Since(start).Seconds()) > 60 {
condition := atomic.LoadInt32(&finished) == 1 || uint64(time.Since(start).Seconds()) > 60
if condition {
break
}

Expand All @@ -340,9 +345,10 @@ func TestIntegrationTopicMessageQueryNoStartTime(t *testing.T) {
_, err = resp.SetValidateStatus(true).GetReceipt(env.Client)
require.NoError(t, err)

if !finished {
if atomic.LoadInt32(&finished) == 0 {
err = errors.New("Message was not received within 60 seconds")
}

assert.NoError(t, err)

err = CloseIntegrationTestEnv(env, nil)
Expand All @@ -353,8 +359,8 @@ func TestIntegrationTopicMessageQueryCanExecuteWithTls(t *testing.T) {
client := ClientForNetwork(map[string]AccountID{})
client.SetMirrorNetwork([]string{"mainnet-public.mirrornode.hedera.com:443"})
client.SetTransportSecurity(true)
var finished int32 = 0 // 0 for false, 1 for true

finished := false
start := time.Now()
end := start.Add(5 * time.Second)

Expand All @@ -363,21 +369,20 @@ func TestIntegrationTopicMessageQueryCanExecuteWithTls(t *testing.T) {
SetStartTime(time.Unix(0, 0)).
SetLimit(10).
SetCompletionHandler(func() {
finished = true
atomic.StoreInt32(&finished, 1)
}).
Subscribe(client, func(message TopicMessage) {
})
require.NoError(t, err)

for {
if finished || time.Now().After(end) {
condition := atomic.LoadInt32(&finished) == 1 || time.Now().After(end)
if condition {
break
}

time.Sleep(2 * time.Second)
}

require.True(t, finished)
require.True(t, atomic.LoadInt32(&finished) == 1)

handle.Unsubscribe()
}
Loading