diff --git a/pkg/transport/sqs/encode_decode.go b/pkg/transport/sqs/encode_decode.go new file mode 100644 index 00000000..323b3c2e --- /dev/null +++ b/pkg/transport/sqs/encode_decode.go @@ -0,0 +1,24 @@ +package sqs + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/service/sqs" + "github.com/aws/aws-sdk-go-v2/service/sqs/types" +) + +// DecodeRequestFunc extracts a user-domain request object from +// an SQS message object. It is designed to be used in Consumers. +type DecodeRequestFunc func(context.Context, types.Message) (request interface{}, err error) + +// EncodeRequestFunc encodes the passed payload object into +// an SQS message object. It is designed to be used in Producers. +type EncodeRequestFunc func(context.Context, *sqs.SendMessageInput, interface{}) error + +// EncodeResponseFunc encodes the passed response object to +// an SQS message object. It is designed to be used in Consumers. +type EncodeResponseFunc func(context.Context, *sqs.SendMessageInput, interface{}) error + +// DecodeResponseFunc extracts a user-domain response object from +// an SQS message object. It is designed to be used in Producers. +type DecodeResponseFunc func(context.Context, types.Message) (response interface{}, err error) diff --git a/pkg/transport/sqs/publisher.go b/pkg/transport/sqs/publisher.go new file mode 100644 index 00000000..cd2b7dc4 --- /dev/null +++ b/pkg/transport/sqs/publisher.go @@ -0,0 +1,136 @@ +package sqs + +import ( + "context" + "encoding/json" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sqs" + "github.com/aws/aws-sdk-go-v2/service/sqs/types" + "github.com/go-kit/kit/endpoint" +) + +type contextKey int + +const ( + // ContextKeyResponseQueueURL is the context key that allows fetching + // and setting the response queue URL from and into context. + ContextKeyResponseQueueURL contextKey = iota +) + +type ( + SQSPublisher interface { + Publish(ctx context.Context, message *sqs.SendMessageInput) (*sqs.SendMessageOutput, error) + } + + // Publisher wraps an Publisher client, and provides a method that + // implements endpoint.Endpoint. + Publisher struct { + Handler SQSPublisher + queueURL string + enc EncodeRequestFunc + dec DecodeResponseFunc + before []PublisherRequestFunc + after []PublisherResponseFunc + } +) + +// NewPublisher constructs a usable Publisher for a single remote method. +func NewPublisher( + handler SQSPublisher, + queueURL string, + enc EncodeRequestFunc, + dec DecodeResponseFunc, + options ...PublisherOption, +) *Publisher { + p := &Publisher{ + Handler: handler, + queueURL: queueURL, + enc: enc, + dec: dec, + } + for _, option := range options { + option(p) + } + return p +} + +// PublisherOption sets an optional parameter for clients. +type PublisherOption func(*Publisher) + +// PublisherBefore sets the RequestFuncs that are applied to the outgoing SQS +// request before it's invoked. +func PublisherBefore(before ...PublisherRequestFunc) PublisherOption { + return func(p *Publisher) { p.before = append(p.before, before...) } +} + +// PublisherAfter sets the ClientResponseFuncs applied to the incoming SQS +// request prior to it being decoded. This is useful for obtaining the response +// and adding any information onto the context prior to decoding. +func PublisherAfter(after ...PublisherResponseFunc) PublisherOption { + return func(p *Publisher) { p.after = append(p.after, after...) } +} + +// SetPublisherResponseQueueURL can be used as a before function to add +// provided url as responseQueueURL in context. +func SetPublisherResponseQueueURL(url string) PublisherRequestFunc { + return func(ctx context.Context, _ *sqs.SendMessageInput) context.Context { + return context.WithValue(ctx, ContextKeyResponseQueueURL, url) + } +} + +// Endpoint returns a usable endpoint that invokes the remote endpoint. +func (p Publisher) Endpoint() endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + msgInput := sqs.SendMessageInput{ + QueueUrl: &p.queueURL, + } + if err := p.enc(ctx, &msgInput, request); err != nil { + return nil, err + } + + for _, f := range p.before { + ctx = f(ctx, &msgInput) + } + + output, err := p.Handler.Publish(ctx, &msgInput) + if err != nil { + return nil, err + } + + var responseMsg types.Message + for _, f := range p.after { + ctx, responseMsg, err = f(ctx, p.Handler, output) + if err != nil { + return nil, err + } + } + + response, err := p.dec(ctx, responseMsg) + if err != nil { + return nil, err + } + + return response, nil + } +} + +// EncodeJSONRequest is an EncodeRequestFunc that serializes the request as a +// JSON object and loads it as the MessageBody of the sqs.SendMessageInput. +// This can be enough for most JSON over SQS communications. +func EncodeJSONRequest(_ context.Context, msg *sqs.SendMessageInput, request interface{}) error { + b, err := json.Marshal(request) + if err != nil { + return err + } + + msg.MessageBody = aws.String(string(b)) + + return nil +} + +// NoResponseDecode is a DecodeResponseFunc that can be used when no response is needed. +// It returns nil value and nil error. +func NoResponseDecode(_ context.Context, _ types.Message) (interface{}, error) { + return nil, nil +} diff --git a/pkg/transport/sqs/publisher_test.go b/pkg/transport/sqs/publisher_test.go new file mode 100644 index 00000000..7819c9d5 --- /dev/null +++ b/pkg/transport/sqs/publisher_test.go @@ -0,0 +1,329 @@ +package sqs + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sqs" + "github.com/aws/aws-sdk-go-v2/service/sqs/types" +) + +type testReq struct { + Squadron int `json:"s"` +} + +type testRes struct { + Squadron int `json:"s"` + Name string `json:"n"` +} + +var names = map[int]string{ + 424: "tiger", + 426: "thunderbird", + 429: "bison", + 436: "tusker", + 437: "husky", +} + +// mockClient is a mock of SQS Handler. +type mockClient struct { + SQSPublisher + SQSSubscriber + err error + sendOutputChan chan types.Message + receiveOutputChan chan *sqs.ReceiveMessageOutput + sendMsgID string + deleteError error +} + +func (mock *mockClient) Publish(ctx context.Context, input *sqs.SendMessageInput) (*sqs.SendMessageOutput, error) { + if input != nil && input.MessageBody != nil && *input.MessageBody != "" { + go func() { + mock.receiveOutputChan <- &sqs.ReceiveMessageOutput{ + Messages: []types.Message{ + { + MessageAttributes: input.MessageAttributes, + Body: input.MessageBody, + MessageId: aws.String(mock.sendMsgID), + }, + }, + } + }() + return &sqs.SendMessageOutput{MessageId: aws.String(mock.sendMsgID)}, nil + } + // Add logic to allow context errors. + for { + select { + case d := <-mock.sendOutputChan: + return &sqs.SendMessageOutput{MessageId: d.MessageId}, mock.err + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + +// TestBadEncode tests if encode errors are handled properly. +func TestBadEncode(t *testing.T) { + queueURL := "someURL" + mock := &mockClient{ + sendOutputChan: make(chan types.Message), + } + pub := NewPublisher( + mock, + queueURL, + func(context.Context, *sqs.SendMessageInput, interface{}) error { return errors.New("err!") }, + func(context.Context, types.Message) (response interface{}, err error) { return struct{}{}, nil }, + ) + errChan := make(chan error, 1) + var err error + go func() { + _, pubErr := pub.Endpoint()(context.Background(), struct{}{}) + errChan <- pubErr + + }() + select { + case err = <-errChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for result") + } + if err == nil { + t.Error("expected error") + } + if want, have := "err!", err.Error(); want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestBadDecode tests if decode errors are handled properly. +func TestBadDecode(t *testing.T) { + mock := &mockClient{ + sendOutputChan: make(chan types.Message), + } + go func() { + mock.sendOutputChan <- types.Message{ + MessageId: aws.String("someMsgID"), + } + }() + + queueURL := "someURL" + pub := NewPublisher( + mock, + queueURL, + func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, + func(context.Context, types.Message) (response interface{}, err error) { + return struct{}{}, errors.New("err!") + }, + PublisherAfter(func( + ctx context.Context, _ SQSPublisher, msg *sqs.SendMessageOutput) (context.Context, types.Message, error) { + // Set the actual response for the request. + return ctx, types.Message{Body: aws.String("someMsgContent")}, nil + }), + ) + + var err error + errChan := make(chan error, 1) + go func() { + _, pubErr := pub.Endpoint()(context.Background(), struct{}{}) + errChan <- pubErr + }() + + select { + case err = <-errChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for result") + } + + if err == nil { + t.Error("expected error") + } + if want, have := "err!", err.Error(); want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestSuccessfulPublisher ensures that the producer mechanisms work. +func TestSuccessfulPublisher(t *testing.T) { + mockReq := testReq{437} + mockRes := testRes{ + Squadron: mockReq.Squadron, + Name: names[mockReq.Squadron], + } + b, err := json.Marshal(mockRes) + if err != nil { + t.Fatal(err) + } + mock := &mockClient{ + sendOutputChan: make(chan types.Message), + sendMsgID: "someMsgID", + } + go func() { + mock.sendOutputChan <- types.Message{ + MessageId: aws.String("someMsgID"), + } + }() + + queueURL := "someURL" + pub := NewPublisher( + mock, + queueURL, + EncodeJSONRequest, + func(_ context.Context, msg types.Message) (interface{}, error) { + response := testRes{} + if err := json.Unmarshal([]byte(*msg.Body), &response); err != nil { + return nil, err + } + return response, nil + }, + PublisherAfter(func( + ctx context.Context, _ SQSPublisher, msg *sqs.SendMessageOutput) (context.Context, types.Message, error) { + // Sets the actual response for the request. + if *msg.MessageId == "someMsgID" { + return ctx, types.Message{Body: aws.String(string(b))}, nil + } + return nil, types.Message{}, fmt.Errorf("Did not receive expected SendMessageOutput") + }), + ) + var res testRes + var ok bool + resChan := make(chan interface{}, 1) + errChan := make(chan error, 1) + go func() { + r, pubErr := pub.Endpoint()(context.Background(), mockReq) + if pubErr != nil { + errChan <- pubErr + } else { + resChan <- r + } + }() + + select { + case response := <-resChan: + res, ok = response.(testRes) + if !ok { + t.Error("failed to assert endpoint response type") + } + break + + case err = <-errChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for result") + } + + if err != nil { + t.Fatal(err) + } + if want, have := mockRes.Name, res.Name; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestSuccessfulPublisherNoResponse ensures that the producer response mechanism works. +func TestSuccessfulPublisherNoResponse(t *testing.T) { + mock := &mockClient{ + sendOutputChan: make(chan types.Message), + receiveOutputChan: make(chan *sqs.ReceiveMessageOutput), + sendMsgID: "someMsgID", + } + + queueURL := "someURL" + pub := NewPublisher( + mock, + queueURL, + EncodeJSONRequest, + NoResponseDecode, + ) + var err error + errChan := make(chan error, 1) + finishChan := make(chan bool, 1) + go func() { + _, pubErr := pub.Endpoint()(context.Background(), struct{}{}) + if pubErr != nil { + errChan <- pubErr + } else { + finishChan <- true + } + }() + + select { + case <-finishChan: + break + case err = <-errChan: + t.Errorf("unexpected error %s", err) + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for result") + } +} + +// TestPublisherWithBefore adds a PublisherBefore function that adds responseQueueURL to context, +// and another on that adds it as a message attribute to outgoing message. +// This test ensures that setting multiple before functions work as expected +// and that SetPublisherResponseQueueURL works as expected. +func TestPublisherWithBefore(t *testing.T) { + mock := &mockClient{ + sendOutputChan: make(chan types.Message), + receiveOutputChan: make(chan *sqs.ReceiveMessageOutput), + sendMsgID: "someMsgID", + } + + queueURL := "someURL" + responseQueueURL := "someOtherURL" + pub := NewPublisher( + mock, + queueURL, + EncodeJSONRequest, + NoResponseDecode, + PublisherBefore(SetPublisherResponseQueueURL(responseQueueURL)), + PublisherBefore(func(c context.Context, s *sqs.SendMessageInput) context.Context { + responseQueueURL := c.Value(ContextKeyResponseQueueURL).(string) + if s.MessageAttributes == nil { + s.MessageAttributes = make(map[string]types.MessageAttributeValue) + } + s.MessageAttributes["responseQueueURL"] = types.MessageAttributeValue{ + DataType: aws.String("String"), + StringValue: &responseQueueURL, + } + return c + }), + ) + var err error + errChan := make(chan error, 1) + go func() { + _, pubErr := pub.Endpoint()(context.Background(), struct{}{}) + if pubErr != nil { + errChan <- pubErr + } + }() + + want := types.MessageAttributeValue{ + DataType: aws.String("String"), + StringValue: &responseQueueURL, + } + + select { + case receiveOutput := <-mock.receiveOutputChan: + if len(receiveOutput.Messages) != 1 { + t.Errorf("published %d messages instead of 1", len(receiveOutput.Messages)) + } + if have, exists := receiveOutput.Messages[0].MessageAttributes["responseQueueURL"]; !exists { + t.Errorf("expected MessageAttributes responseQueueURL not found") + } else if *have.StringValue != responseQueueURL || *have.DataType != "String" { + t.Errorf("want %v, have %v", want, have) + } + break + case err = <-errChan: + t.Errorf("unexpected error %s", err) + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for result") + } +} diff --git a/pkg/transport/sqs/request_response_funcs.go b/pkg/transport/sqs/request_response_funcs.go new file mode 100644 index 00000000..24db77d7 --- /dev/null +++ b/pkg/transport/sqs/request_response_funcs.go @@ -0,0 +1,36 @@ +package sqs + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/service/sqs" + "github.com/aws/aws-sdk-go-v2/service/sqs/types" +) + +// SubscriberRequestFunc may take information from a consumer request result and +// put it into a request context. In Subscribers, RequestFuncs are executed prior +// to invoking the endpoint. +// use cases eg. in Subscriber : extract message information into context. +type SubscriberRequestFunc func( + ctx context.Context, cancel context.CancelFunc, message types.Message) context.Context + +// PublisherRequestFunc may take information from a producer request and put it into a +// request context, or add some informations to SendMessageInput. In Publishers, +// RequestFuncs are executed prior to publishing the message but after encoding. +// use cases eg. in Publisher : enforce some message attributes to SendMessageInput. +type PublisherRequestFunc func(ctx context.Context, input *sqs.SendMessageInput) context.Context + +// SubscriberResponseFunc may take information from a request context and use it to +// manipulate a Publisher. SubscriberResponseFunc are only executed in +// consumers, after invoking the endpoint but prior to publishing a reply. +// use cases eg. : Pipe information from request message to response MessageInput, +// delete msg from queue or update leftMsgs slice. +type SubscriberResponseFunc func( + ctx context.Context, cancel context.CancelFunc, message types.Message, resp interface{}) context.Context + +// PublisherResponseFunc may take information from an sqs.SendMessageOutput and +// fetch response using the Client. SQS is not req-reply out-of-the-box. Responses need to be fetched. +// PublisherResponseFunc are only executed in producers, after a request has been made, +// but prior to its response being decoded. So this is the perfect place to fetch actual response. +type PublisherResponseFunc func( + context.Context, SQSPublisher, *sqs.SendMessageOutput) (context.Context, types.Message, error) diff --git a/pkg/transport/sqs/subscriber.go b/pkg/transport/sqs/subscriber.go new file mode 100644 index 00000000..2db871b5 --- /dev/null +++ b/pkg/transport/sqs/subscriber.go @@ -0,0 +1,222 @@ +package sqs + +import ( + "context" + "encoding/json" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sqs" + "github.com/aws/aws-sdk-go-v2/service/sqs/types" + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/transport" +) + +type ( + SQSSubscriber interface { + ReceiveMessages(ctx context.Context, input *sqs.ReceiveMessageInput) (*sqs.ReceiveMessageOutput, error) + ChangeMessageVisibility( + ctx context.Context, input *sqs.ChangeMessageVisibilityInput) (*sqs.ChangeMessageVisibilityOutput, error) + DeleteMessage(ctx context.Context, input *sqs.DeleteMessageInput) (*sqs.DeleteMessageOutput, error) + } + + // Subscriber wraps an endpoint and provides a handler for SQS messages. + Subscriber struct { + sqsClient SQSSubscriber + e endpoint.Endpoint + dec DecodeRequestFunc + enc EncodeResponseFunc + queueURL string + before []SubscriberRequestFunc + after []SubscriberResponseFunc + errorEncoder ErrorEncoder + finalizer []SubscriberFinalizerFunc + errorHandler transport.ErrorHandler + } +) + +// NewSubscriber constructs a new Subscriber, which provides a Consume method +// and message handlers that wrap the provided endpoint. +func NewSubscriber( + sqsClient SQSSubscriber, + e endpoint.Endpoint, + dec DecodeRequestFunc, + enc EncodeResponseFunc, + queueURL string, + options ...SubscriberOption, +) *Subscriber { + s := &Subscriber{ + sqsClient: sqsClient, + e: e, + dec: dec, + enc: enc, + queueURL: queueURL, + errorEncoder: DefaultErrorEncoder, + errorHandler: transport.NewLogErrorHandler(log.NewNopLogger()), + } + for _, option := range options { + option(s) + } + return s +} + +// SubscriberOption sets an optional parameter for consumers. +type SubscriberOption func(*Subscriber) + +// SubscriberBefore functions are executed on the producer request object before the +// request is decoded. +func SubscriberBefore(before ...SubscriberRequestFunc) SubscriberOption { + return func(s *Subscriber) { s.before = append(s.before, before...) } +} + +// SubscriberAfter functions are executed on the consumer reply after the +// endpoint is invoked, but before anything is published to the reply. +func SubscriberAfter(after ...SubscriberResponseFunc) SubscriberOption { + return func(s *Subscriber) { s.after = append(s.after, after...) } +} + +// SubscriberErrorEncoder is used to encode errors to the consumer reply +// whenever they're encountered in the processing of a request. Clients can +// use this to provide custom error formatting. By default, +// errors will be published with the DefaultErrorEncoder. +func SubscriberErrorEncoder(ee ErrorEncoder) SubscriberOption { + return func(s *Subscriber) { s.errorEncoder = ee } +} + +// SubscriberErrorHandler is used to handle non-terminal errors. By default, non-terminal errors +// are ignored. This is intended as a diagnostic measure. Finer-grained control +// of error handling, including logging in more detail, should be performed in a +// custom SubscriberErrorEncoder which has access to the context. +func SubscriberErrorHandler(errorHandler transport.ErrorHandler) SubscriberOption { + return func(s *Subscriber) { s.errorHandler = errorHandler } +} + +// SubscriberFinalizer is executed once all the received SQS messages are done being processed. +// By default, no finalizer is registered. +func SubscriberFinalizer(f ...SubscriberFinalizerFunc) SubscriberOption { + return func(s *Subscriber) { s.finalizer = f } +} + +// SubscriberDeleteMessageBefore returns a SubscriberOption that appends a function +// that delete the message from queue to the list of consumer's before functions. +func SubscriberDeleteMessageBefore() SubscriberOption { + return func(s *Subscriber) { + deleteBefore := func(ctx context.Context, cancel context.CancelFunc, msg types.Message) context.Context { + if err := deleteMessage(ctx, s.sqsClient, s.queueURL, msg); err != nil { + s.errorHandler.Handle(ctx, err) + s.errorEncoder(ctx, err, msg, s.sqsClient) + cancel() + } + return ctx + } + s.before = append(s.before, deleteBefore) + } +} + +// SubscriberDeleteMessageAfter returns a SubscriberOption that appends a function +// that delete a message from queue to the list of consumer's after functions. +func SubscriberDeleteMessageAfter() SubscriberOption { + return func(s *Subscriber) { + deleteAfter := func( + ctx context.Context, cancel context.CancelFunc, msg types.Message, _ interface{}) context.Context { + if err := deleteMessage(ctx, s.sqsClient, s.queueURL, msg); err != nil { + s.errorHandler.Handle(ctx, err) + s.errorEncoder(ctx, err, msg, s.sqsClient) + cancel() + } + return ctx + } + s.after = append(s.after, deleteAfter) + } +} + +// ServeMessage serves an SQS message. +func (s Subscriber) ServeMessage(ctx context.Context) func(msg types.Message) error { + return func(msg types.Message) error { + newCtx, cancel := context.WithCancel(ctx) + defer cancel() + + if len(s.finalizer) > 0 { + defer func() { + for _, f := range s.finalizer { + f(newCtx, msg) + } + }() + } + + for _, f := range s.before { + newCtx = f(newCtx, cancel, msg) + } + + req, err := s.dec(newCtx, msg) + if err != nil { + s.errorHandler.Handle(newCtx, err) + s.errorEncoder(newCtx, err, msg, s.sqsClient) + return err + } + + response, err := s.e(newCtx, req) + if err != nil { + s.errorHandler.Handle(newCtx, err) + s.errorEncoder(newCtx, err, msg, s.sqsClient) + return err + } + + for _, f := range s.after { + newCtx = f(newCtx, cancel, msg, response) + } + + return nil + } +} + +// ErrorEncoder is responsible for encoding an error to the consumer's reply. +// Users are encouraged to use custom ErrorEncoders to encode errors to +// their replies, and will likely want to pass and check for their own error +// types. +type ErrorEncoder func(ctx context.Context, err error, req types.Message, sqsClient SQSSubscriber) + +// SubscriberFinalizerFunc can be used to perform work at the end of a request +// from a producer, after the response has been written to the producer. The +// principal intended use is for request logging. +// Can also be used to delete messages once fully proccessed. +type SubscriberFinalizerFunc func(ctx context.Context, msg types.Message) + +// DefaultErrorEncoder simply ignores the message. It does not reply. +func DefaultErrorEncoder(context.Context, error, types.Message, SQSSubscriber) { +} + +// SubscriberNackMessageErrorEncoder can be used to perform an immediate nack on the message. +func SubscriberNackMessageErrorEncoder() SubscriberOption { + return func(s *Subscriber) { + nackErrorHandler := func(ctx context.Context, err error, msg types.Message, sqsClient SQSSubscriber) { + _, sqsErr := sqsClient.ChangeMessageVisibility(ctx, &sqs.ChangeMessageVisibilityInput{ + QueueUrl: &s.queueURL, + ReceiptHandle: msg.ReceiptHandle, + VisibilityTimeout: 1, + }) + if sqsErr != nil { + s.errorHandler.Handle(ctx, sqsErr) + } + } + s.errorEncoder = nackErrorHandler + } +} + +func deleteMessage(ctx context.Context, sqsClient SQSSubscriber, queueURL string, msg types.Message) error { + _, err := sqsClient.DeleteMessage(ctx, &sqs.DeleteMessageInput{ + QueueUrl: &queueURL, + ReceiptHandle: msg.ReceiptHandle, + }) + return err +} + +// EncodeJSONResponse marshals response as json and loads it into an sqs.SendMessageInput MessageBody. +func EncodeJSONResponse(_ context.Context, input *sqs.SendMessageInput, response interface{}) error { + payload, err := json.Marshal(response) + if err != nil { + return err + } + input.MessageBody = aws.String(string(payload)) + return nil +} diff --git a/pkg/transport/sqs/subscriber_test.go b/pkg/transport/sqs/subscriber_test.go new file mode 100644 index 00000000..d6d1d5fa --- /dev/null +++ b/pkg/transport/sqs/subscriber_test.go @@ -0,0 +1,433 @@ +package sqs + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sqs" + "github.com/aws/aws-sdk-go-v2/service/sqs/types" +) + +const ( + testErrMessage = "err!" +) + +var ( + errTypeAssertion = errors.New("type assertion error") +) + +func (mock *mockClient) ReceiveMessage( + ctx context.Context, input *sqs.ReceiveMessageInput) (*sqs.ReceiveMessageOutput, error) { + // Add logic to allow context errors. + for { + select { + case d := <-mock.receiveOutputChan: + return d, mock.err + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + +func (mock *mockClient) DeleteMessage( + ctx context.Context, input *sqs.DeleteMessageInput) (*sqs.DeleteMessageOutput, error) { + return nil, mock.deleteError +} + +// TestSubscriberDeleteBefore checks if deleteMessage is set properly using subscriber options. +func TestSubscriberDeleteBefore(t *testing.T) { + queueURL := "someURL" + mock := &mockClient{ + sendOutputChan: make(chan types.Message), + receiveOutputChan: make(chan *sqs.ReceiveMessageOutput), + deleteError: fmt.Errorf("delete err!"), + } + errEncoder := SubscriberErrorEncoder(func( + ctx context.Context, err error, req types.Message, sqsClient SQSSubscriber) { + publishError := sqsError{ + Err: err.Error(), + MsgID: *req.MessageId, + } + payload, err := json.Marshal(publishError) + if err != nil { + t.Fatal(err) + } + + publisher := sqsClient.(*mockClient) + _, err = publisher.Publish(ctx, &sqs.SendMessageInput{ + MessageBody: aws.String(string(payload)), + }) + if err != nil { + t.Fatal(err) + } + }) + subscriber := NewSubscriber(mock, + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, types.Message) (interface{}, error) { return nil, nil }, + func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, + queueURL, + errEncoder, + SubscriberDeleteMessageBefore(), + ) + + err := subscriber.ServeMessage(context.Background())(types.Message{ + Body: aws.String("MessageBody"), + MessageId: aws.String("fakeMsgID"), + }) + if err != nil { + t.Fatal(err) + } + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOutputChan: + break + + case <-time.After(200 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + res, err := decodeSubscriberError(receiveOutput) + if err != nil { + t.Fatal(err) + } + if want, have := "delete err!", res.Err; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestSubscriberBadDecode checks if decoder errors are handled properly. +func TestSubscriberBadDecode(t *testing.T) { + queueURL := "someURL" + mock := &mockClient{ + sendOutputChan: make(chan types.Message), + receiveOutputChan: make(chan *sqs.ReceiveMessageOutput), + } + errEncoder := SubscriberErrorEncoder(func( + ctx context.Context, err error, req types.Message, sqsClient SQSSubscriber) { + publishError := sqsError{ + Err: err.Error(), + MsgID: *req.MessageId, + } + payload, err := json.Marshal(publishError) + if err != nil { + t.Fatal(err) + } + + publisher := sqsClient.(*mockClient) + _, err = publisher.Publish(ctx, &sqs.SendMessageInput{ + MessageBody: aws.String(string(payload)), + }) + if err != nil { + t.Fatal(err) + } + }) + subscriber := NewSubscriber(mock, + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, types.Message) (interface{}, error) { return nil, errors.New(testErrMessage) }, + func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, + queueURL, + errEncoder, + ) + + err := subscriber.ServeMessage(context.Background())(types.Message{ + Body: aws.String("MessageBody"), + MessageId: aws.String("fakeMsgID"), + }) + if err == nil { + t.Errorf("expected error") + } + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOutputChan: + break + + case <-time.After(200 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + res, err := decodeSubscriberError(receiveOutput) + if err != nil { + t.Fatal(err) + } + if want, have := testErrMessage, res.Err; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestSubscriberBadEndpoint checks if endpoint errors are handled properly. +func TestSubscriberBadEndpoint(t *testing.T) { + queueURL := "someURL" + mock := &mockClient{ + sendOutputChan: make(chan types.Message), + receiveOutputChan: make(chan *sqs.ReceiveMessageOutput), + } + errEncoder := SubscriberErrorEncoder(func( + ctx context.Context, err error, req types.Message, sqsClient SQSSubscriber) { + publishError := sqsError{ + Err: err.Error(), + MsgID: *req.MessageId, + } + payload, err := json.Marshal(publishError) + if err != nil { + t.Fatal(err) + } + + publisher := sqsClient.(*mockClient) + _, err = publisher.Publish(ctx, &sqs.SendMessageInput{ + MessageBody: aws.String(string(payload)), + }) + if err != nil { + t.Fatal(err) + } + }) + subscriber := NewSubscriber(mock, + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errors.New(testErrMessage) }, + func(context.Context, types.Message) (interface{}, error) { return nil, nil }, + func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, + queueURL, + errEncoder, + ) + + err := subscriber.ServeMessage(context.Background())(types.Message{ + Body: aws.String("MessageBody"), + MessageId: aws.String("fakeMsgID"), + }) + if err == nil { + t.Errorf("expected error") + } + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOutputChan: + break + + case <-time.After(200 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + res, err := decodeSubscriberError(receiveOutput) + if err != nil { + t.Fatal(err) + } + if want, have := testErrMessage, res.Err; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestSubscriberSuccess checks if subscriber responds correctly to message. +func TestSubscriberSuccess(t *testing.T) { + obj := testReq{ + Squadron: 436, + } + b, err := json.Marshal(obj) + if err != nil { + t.Fatal(err) + } + queueURL := "someURL" + mock := &mockClient{ + sendOutputChan: make(chan types.Message), + receiveOutputChan: make(chan *sqs.ReceiveMessageOutput), + } + subscriber := NewSubscriber(mock, + testEndpoint, + testReqDecoderfunc, + EncodeJSONResponse, + queueURL, + SubscriberAfter(func( + ctx context.Context, cancel context.CancelFunc, msg types.Message, resp interface{}) context.Context { + _, err = mock.Publish(context.Background(), &sqs.SendMessageInput{ + MessageBody: msg.Body, + }) + if err != nil { + t.Fatal(err) + } + + return ctx + }), + ) + + err = subscriber.ServeMessage(context.Background())(types.Message{ + Body: aws.String(string(b)), + MessageId: aws.String("fakeMsgID"), + }) + if err != nil { + t.Fatal(err) + } + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOutputChan: + break + + case <-time.After(200 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + res, err := decodeResponse(receiveOutput) + if err != nil { + t.Fatal(err) + } + want := testRes{ + Squadron: 436, + } + if have := res; want != have { + t.Errorf("want %v, have %v", want, have) + } +} + +// TestSubscriberSuccessNoReply checks if subscriber processes correctly message +// without sending response. +func TestSubscriberSuccessNoReply(t *testing.T) { + obj := testReq{ + Squadron: 436, + } + b, err := json.Marshal(obj) + if err != nil { + t.Fatal(err) + } + queueURL := "someURL" + mock := &mockClient{ + sendOutputChan: make(chan types.Message), + receiveOutputChan: make(chan *sqs.ReceiveMessageOutput), + } + subscriber := NewSubscriber(mock, + testEndpoint, + testReqDecoderfunc, + EncodeJSONResponse, + queueURL, + ) + + err = subscriber.ServeMessage(context.Background())(types.Message{ + Body: aws.String(string(b)), + MessageId: aws.String("fakeMsgID"), + }) + if err != nil { + t.Fatal(err) + } + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOutputChan: + t.Errorf("received output when none was expected, have %v", receiveOutput) + return + + case <-time.After(200 * time.Millisecond): + // As expected, we did not receive any response from subscriber. + return + } +} + +// TestSubscriberAfter checks if subscriber after is called as expected. +// Here after is used to transfer some info from received message in response. +func TestSubscriberAfter(t *testing.T) { + obj1 := testReq{ + Squadron: 436, + } + b1, err := json.Marshal(obj1) + if err != nil { + t.Fatal(err) + } + queueURL := "someURL" + mock := &mockClient{ + sendOutputChan: make(chan types.Message), + receiveOutputChan: make(chan *sqs.ReceiveMessageOutput), + } + correlationID := "test" + msg := types.Message{ + Body: aws.String(string(b1)), + MessageId: aws.String("fakeMsgID1"), + MessageAttributes: map[string]types.MessageAttributeValue{ + "correlationID": { + DataType: aws.String("String"), + StringValue: &correlationID, + }, + }, + } + subscriber := NewSubscriber(mock, + testEndpoint, + testReqDecoderfunc, + EncodeJSONResponse, + queueURL, + SubscriberAfter(func( + ctx context.Context, cancel context.CancelFunc, msg types.Message, resp interface{}) context.Context { + _, err := mock.Publish(ctx, &sqs.SendMessageInput{ + MessageBody: msg.Body, + MessageAttributes: msg.MessageAttributes, + }) + if err != nil { + t.Fatal(err) + } + + return ctx + }), + ) + ctx := context.Background() + err = subscriber.ServeMessage(ctx)(msg) + if err != nil { + t.Fatal(err) + } + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOutputChan: + break + + case <-time.After(200 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + if len(receiveOutput.Messages) != 1 { + t.Errorf("received %d messages instead of 1", len(receiveOutput.Messages)) + } + if correlationIDAttribute, exists := receiveOutput.Messages[0].MessageAttributes["correlationID"]; exists { + if have := correlationIDAttribute.StringValue; *have != correlationID { + t.Errorf("have %s, want %s", *have, correlationID) + } + } else { + t.Errorf("expected message attribute with key correlationID in response, but it was not found") + } +} + +type sqsError struct { + Err string `json:"err"` + MsgID string `json:"msgID"` +} + +func decodeSubscriberError(receiveOutput *sqs.ReceiveMessageOutput) (sqsError, error) { + receivedError := sqsError{} + err := json.Unmarshal([]byte(*receiveOutput.Messages[0].Body), &receivedError) + return receivedError, err +} + +func testEndpoint(ctx context.Context, request interface{}) (interface{}, error) { + req, ok := request.(testReq) + if !ok { + return nil, errTypeAssertion + } + name, prs := names[req.Squadron] + if !prs { + return nil, errors.New("unknown squadron name") + } + res := testRes{ + Squadron: req.Squadron, + Name: name, + } + return res, nil +} + +func testReqDecoderfunc(_ context.Context, msg types.Message) (interface{}, error) { + var obj testReq + err := json.Unmarshal([]byte(*msg.Body), &obj) + return obj, err +} + +func decodeResponse(receiveOutput *sqs.ReceiveMessageOutput) (interface{}, error) { + if len(receiveOutput.Messages) != 1 { + return nil, fmt.Errorf("Error : received %d messages instead of 1", len(receiveOutput.Messages)) + } + resp := testRes{} + err := json.Unmarshal([]byte(*receiveOutput.Messages[0].Body), &resp) + return resp, err +}