@@ -57,6 +57,33 @@ func TestServerStartStop(t *testing.T) {
57
57
assert .NoError (t , err )
58
58
}
59
59
60
+ func TestServerStartStopWithMiddleware (t * testing.T ) {
61
+ var addedMiddleware atomic.Bool
62
+ assert .False (t , addedMiddleware .Load ())
63
+
64
+ testHTTPMiddleware := func (handler http.Handler ) http.Handler {
65
+ addedMiddleware .Store (true )
66
+ return http .HandlerFunc (
67
+ func (w http.ResponseWriter , r * http.Request ) {
68
+ handler .ServeHTTP (w , r )
69
+ },
70
+ )
71
+ }
72
+
73
+ startSettings := & StartSettings {
74
+ HTTPMiddleware : testHTTPMiddleware ,
75
+ }
76
+
77
+ srv := startServer (t , startSettings )
78
+ assert .True (t , addedMiddleware .Load ())
79
+
80
+ err := srv .Start (* startSettings )
81
+ assert .ErrorIs (t , err , errAlreadyStarted )
82
+
83
+ err = srv .Stop (context .Background ())
84
+ assert .NoError (t , err )
85
+ }
86
+
60
87
func TestServerAddrWithNonZeroPort (t * testing.T ) {
61
88
srv := New (& sharedinternal.NopLogger {})
62
89
require .NotNil (t , srv )
@@ -830,6 +857,109 @@ func TestConnectionAllowsConcurrentWrites(t *testing.T) {
830
857
}
831
858
}
832
859
860
+ func TestServerCallsHTTPMiddlewareOverWebsocket (t * testing.T ) {
861
+ middlewareCalled := int32 (0 )
862
+
863
+ testHTTPMiddleware := func (handler http.Handler ) http.Handler {
864
+ return http .HandlerFunc (
865
+ func (w http.ResponseWriter , r * http.Request ) {
866
+ atomic .AddInt32 (& middlewareCalled , 1 )
867
+ handler .ServeHTTP (w , r )
868
+ },
869
+ )
870
+ }
871
+
872
+ callbacks := CallbacksStruct {
873
+ OnConnectingFunc : func (request * http.Request ) types.ConnectionResponse {
874
+ return types.ConnectionResponse {
875
+ Accept : true ,
876
+ ConnectionCallbacks : ConnectionCallbacksStruct {},
877
+ }
878
+ },
879
+ }
880
+
881
+ // Start a Server
882
+ settings := & StartSettings {
883
+ HTTPMiddleware : testHTTPMiddleware ,
884
+ Settings : Settings {Callbacks : callbacks },
885
+ }
886
+ srv := startServer (t , settings )
887
+ defer func () {
888
+ err := srv .Stop (context .Background ())
889
+ assert .NoError (t , err )
890
+ }()
891
+
892
+ // Connect to the server, ensuring successful connection
893
+ conn , resp , err := dialClient (settings )
894
+ assert .NoError (t , err )
895
+ assert .NotNil (t , conn )
896
+ require .NotNil (t , resp )
897
+ assert .EqualValues (t , 101 , resp .StatusCode )
898
+
899
+ // Verify middleware was called once for the websocket connection
900
+ eventually (t , func () bool { return atomic .LoadInt32 (& middlewareCalled ) == int32 (1 ) })
901
+ assert .Equal (t , int32 (1 ), atomic .LoadInt32 (& middlewareCalled ))
902
+ }
903
+
904
+ func TestServerCallsHTTPMiddlewareOverHTTP (t * testing.T ) {
905
+ middlewareCalled := int32 (0 )
906
+
907
+ testHTTPMiddleware := func (handler http.Handler ) http.Handler {
908
+ return http .HandlerFunc (
909
+ func (w http.ResponseWriter , r * http.Request ) {
910
+ atomic .AddInt32 (& middlewareCalled , 1 )
911
+ handler .ServeHTTP (w , r )
912
+ },
913
+ )
914
+ }
915
+
916
+ callbacks := CallbacksStruct {
917
+ OnConnectingFunc : func (request * http.Request ) types.ConnectionResponse {
918
+ return types.ConnectionResponse {
919
+ Accept : true ,
920
+ ConnectionCallbacks : ConnectionCallbacksStruct {},
921
+ }
922
+ },
923
+ }
924
+
925
+ // Start a Server
926
+ settings := & StartSettings {
927
+ HTTPMiddleware : testHTTPMiddleware ,
928
+ Settings : Settings {Callbacks : callbacks },
929
+ }
930
+ srv := startServer (t , settings )
931
+ defer func () {
932
+ err := srv .Stop (context .Background ())
933
+ assert .NoError (t , err )
934
+ }()
935
+
936
+ // Send an AgentToServer message to the Server
937
+ sendMsg1 := protobufs.AgentToServer {InstanceUid : "01BX5ZZKBKACTAV9WEVGEMMVS1" }
938
+ serializedProtoBytes1 , err := proto .Marshal (& sendMsg1 )
939
+ require .NoError (t , err )
940
+ _ , err = http .Post (
941
+ "http://" + settings .ListenEndpoint + settings .ListenPath ,
942
+ contentTypeProtobuf ,
943
+ bytes .NewReader (serializedProtoBytes1 ),
944
+ )
945
+ require .NoError (t , err )
946
+
947
+ // Send another AgentToServer message to the Server
948
+ sendMsg2 := protobufs.AgentToServer {InstanceUid : "01BX5ZZKBKACTAV9WEVGEMMVRZ" }
949
+ serializedProtoBytes2 , err := proto .Marshal (& sendMsg2 )
950
+ require .NoError (t , err )
951
+ _ , err = http .Post (
952
+ "http://" + settings .ListenEndpoint + settings .ListenPath ,
953
+ contentTypeProtobuf ,
954
+ bytes .NewReader (serializedProtoBytes2 ),
955
+ )
956
+ require .NoError (t , err )
957
+
958
+ // Verify middleware was triggered for each HTTP call
959
+ eventually (t , func () bool { return atomic .LoadInt32 (& middlewareCalled ) == int32 (2 ) })
960
+ assert .Equal (t , int32 (2 ), atomic .LoadInt32 (& middlewareCalled ))
961
+ }
962
+
833
963
func BenchmarkSendToClient (b * testing.B ) {
834
964
clientConnections := []* websocket.Conn {}
835
965
serverConnections := []types.Connection {}
0 commit comments