diff --git a/stompserver/errors.go b/stompserver/errors.go index 4f8c9f4..ab8b875 100644 --- a/stompserver/errors.go +++ b/stompserver/errors.go @@ -11,6 +11,7 @@ const ( invalidSubscriptionError = stompErrorMessage("invalid subscription") invalidFrameError = stompErrorMessage("invalid frame") invalidHeaderError = stompErrorMessage("invalid frame header") + invalidSendDestinationError = stompErrorMessage("invalid send destination") ) type stompErrorMessage string diff --git a/stompserver/stomp_connection.go b/stompserver/stomp_connection.go index 95207aa..afa0240 100644 --- a/stompserver/stomp_connection.go +++ b/stompserver/stomp_connection.go @@ -359,16 +359,22 @@ func (conn *stompConn) handleSend(f *frame.Frame) error { return unsupportedStompCommandError } - err := conn.sendReceiptResponse(f) - if err != nil { - return err - } - + // no destination triggers an error dest, ok := f.Header.Contains(frame.Destination) if !ok { return invalidFrameError } + // reject SENDing directly to non-request channels by clients + if !conn.config.IsAppRequestDestination(f.Header.Get(frame.Destination)) { + return invalidSendDestinationError + } + + err := conn.sendReceiptResponse(f) + if err != nil { + return err + } + f.Command = frame.MESSAGE conn.events <- &ConnEvent{ ConnId: conn.GetId(), diff --git a/stompserver/stomp_connection_test.go b/stompserver/stomp_connection_test.go index ddd045c..5b87ef1 100644 --- a/stompserver/stomp_connection_test.go +++ b/stompserver/stomp_connection_test.go @@ -415,20 +415,16 @@ func TestStompConn_Subscribe(t *testing.T) { frame.Id, "sub-id", frame.Destination, "/topic/test") - rawConn.incomingFrames <- frame.New(frame.SEND, frame.Destination, "/topic/dest") - - // verify that there will be no SubscribeToTopic con event for the - // the second request. - e = <- events - assert.Equal(t, e.eventType, IncomingMessage) + // verify that there was no second subscription created for the same subscription id + assert.Equal(t, e.sub, stompConn.subscriptions["sub-id"]) } func TestStompConn_SendNotConnected(t *testing.T) { - _, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + _, rawConn, events := getTestStompConn(NewStompConfig(0, []string{"/pub/"}), nil) rawConn.incomingFrames <- frame.New( frame.SEND, - frame.Destination, "/topic/test") + frame.Destination, "/pub/test") e := <- events assert.Equal(t, e.eventType, ConnectionClosed) @@ -439,7 +435,7 @@ func TestStompConn_SendNotConnected(t *testing.T) { } func TestStompConn_SendMissingDestinationHeader(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{"/pub/"}), nil) rawConn.SendConnectFrame() @@ -458,15 +454,36 @@ func TestStompConn_SendMissingDestinationHeader(t *testing.T) { assert.Equal(t, stompConn.state, closed) } +func TestStompConn_Send_InvalidSend(t *testing.T) { + _, rawConn, events := getTestStompConn(NewStompConfig(0, []string{"/pub/"}), nil) + + rawConn.SendConnectFrame() + + e := <- events + assert.Equal(t, e.eventType, ConnectionEstablished) + assert.Equal(t, len(rawConn.sentFrames), 1) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.CONNECTED), false) + + // try sending a frame to a topic channel directly not request channel + rawConn.incomingFrames <- frame.New(frame.SEND, + frame.Destination, "/topic/test") + e = <- events + + assert.Equal(t, e.eventType, ConnectionClosed) + assert.Equal(t, len(rawConn.sentFrames), 2) + verifyFrame(t, rawConn.sentFrames[1], frame.New(frame.ERROR, + frame.Message, invalidSendDestinationError.Error()), true) +} + func TestStompConn_Send(t *testing.T) { - _, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + _, rawConn, events := getTestStompConn(NewStompConfig(0, []string{"/pub/"}), nil) rawConn.SendConnectFrame() e := <- events assert.Equal(t, e.eventType, ConnectionEstablished) - msgF := frame.New(frame.SEND, frame.Destination, "/topic/test") + msgF := frame.New(frame.SEND, frame.Destination, "/pub/test") rawConn.incomingFrames <- msgF @@ -476,7 +493,7 @@ func TestStompConn_Send(t *testing.T) { assert.Equal(t, e.frame.Command, frame.MESSAGE) rawConn.incomingFrames <- frame.New(frame.SEND, - frame.Destination, "/topic/test", frame.Receipt, "receipt-id") + frame.Destination, "/pub/test", frame.Receipt, "receipt-id") e = <- events assert.Equal(t, e.eventType, IncomingMessage) @@ -710,7 +727,7 @@ func TestStompConn_WriteErrorDuringConnect(t *testing.T) { } func TestStompConn_WriteErrorDuringSend(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{"/pub/"}), nil) rawConn.SendConnectFrame() @@ -720,7 +737,7 @@ func TestStompConn_WriteErrorDuringSend(t *testing.T) { rawConn.nextWriteErr = errors.New("write error") rawConn.incomingFrames <- frame.New( frame.SEND, - frame.Destination, "/topic", + frame.Destination, "/pub/", frame.Receipt, "receipt-id") e = <- events