@@ -68,6 +68,7 @@ func New(logger types.Logger) *server {
6868
6969func (s * server ) Attach (settings Settings ) (HTTPHandlerFunc , ConnContext , error ) {
7070 s .settings = settings
71+ s .settings .Callbacks .SetDefaults ()
7172 s .wsUpgrader = websocket.Upgrader {
7273 EnableCompression : settings .EnableCompression ,
7374 }
@@ -169,26 +170,25 @@ func (s *server) Addr() net.Addr {
169170
170171func (s * server ) httpHandler (w http.ResponseWriter , req * http.Request ) {
171172 var connectionCallbacks serverTypes.ConnectionCallbacks
172- if s .settings .Callbacks != nil {
173- resp := s .settings .Callbacks .OnConnecting (req )
174- if ! resp .Accept {
175- // HTTP connection is not accepted. Set the response headers.
176- for k , v := range resp .HTTPResponseHeader {
177- w .Header ().Set (k , v )
178- }
179- // And write the response status code.
180- w .WriteHeader (resp .HTTPStatusCode )
181- return
173+ resp := s .settings .Callbacks .OnConnecting (req )
174+ if ! resp .Accept {
175+ // HTTP connection is not accepted. Set the response headers.
176+ for k , v := range resp .HTTPResponseHeader {
177+ w .Header ().Set (k , v )
182178 }
183- // use connection-specific handler provided by ConnectionResponse
184- connectionCallbacks = resp .ConnectionCallbacks
179+ // And write the response status code.
180+ w .WriteHeader (resp .HTTPStatusCode )
181+ return
185182 }
183+ // use connection-specific handler provided by ConnectionResponse
184+ connectionCallbacks = resp .ConnectionCallbacks
185+ connectionCallbacks .SetDefaults ()
186186
187187 // HTTP connection is accepted. Check if it is a plain HTTP request.
188188
189189 if req .Header .Get (headerContentType ) == contentTypeProtobuf {
190190 // Yes, a plain HTTP request.
191- s .handlePlainHTTPRequest (req , w , connectionCallbacks )
191+ s .handlePlainHTTPRequest (req , w , & connectionCallbacks )
192192 return
193193 }
194194
@@ -201,10 +201,10 @@ func (s *server) httpHandler(w http.ResponseWriter, req *http.Request) {
201201
202202 // Return from this func to reduce memory usage.
203203 // Handle the connection on a separate goroutine.
204- go s .handleWSConnection (req .Context (), conn , connectionCallbacks )
204+ go s .handleWSConnection (req .Context (), conn , & connectionCallbacks )
205205}
206206
207- func (s * server ) handleWSConnection (reqCtx context.Context , wsConn * websocket.Conn , connectionCallbacks serverTypes.ConnectionCallbacks ) {
207+ func (s * server ) handleWSConnection (reqCtx context.Context , wsConn * websocket.Conn , connectionCallbacks * serverTypes.ConnectionCallbacks ) {
208208 agentConn := wsConnection {wsConn : wsConn , connMutex : & sync.Mutex {}}
209209
210210 defer func () {
@@ -216,14 +216,10 @@ func (s *server) handleWSConnection(reqCtx context.Context, wsConn *websocket.Co
216216 }
217217 }()
218218
219- if connectionCallbacks != nil {
220- connectionCallbacks .OnConnectionClose (agentConn )
221- }
219+ connectionCallbacks .OnConnectionClose (agentConn )
222220 }()
223221
224- if connectionCallbacks != nil {
225- connectionCallbacks .OnConnected (reqCtx , agentConn )
226- }
222+ connectionCallbacks .OnConnected (reqCtx , agentConn )
227223
228224 sentCustomCapabilities := false
229225
@@ -254,21 +250,19 @@ func (s *server) handleWSConnection(reqCtx context.Context, wsConn *websocket.Co
254250 continue
255251 }
256252
257- if connectionCallbacks != nil {
258- response := connectionCallbacks .OnMessage (msgContext , agentConn , & request )
259- if len (response .InstanceUid ) == 0 {
260- response .InstanceUid = request .InstanceUid
261- }
262- if ! sentCustomCapabilities {
263- response .CustomCapabilities = & protobufs.CustomCapabilities {
264- Capabilities : s .settings .CustomCapabilities ,
265- }
266- sentCustomCapabilities = true
267- }
268- err = agentConn .Send (msgContext , response )
269- if err != nil {
270- s .logger .Errorf (msgContext , "Cannot send message to WebSocket: %v" , err )
253+ response := connectionCallbacks .OnMessage (msgContext , agentConn , & request )
254+ if len (response .InstanceUid ) == 0 {
255+ response .InstanceUid = request .InstanceUid
256+ }
257+ if ! sentCustomCapabilities {
258+ response .CustomCapabilities = & protobufs.CustomCapabilities {
259+ Capabilities : s .settings .CustomCapabilities ,
271260 }
261+ sentCustomCapabilities = true
262+ }
263+ err = agentConn .Send (msgContext , response )
264+ if err != nil {
265+ s .logger .Errorf (msgContext , "Cannot send message to WebSocket: %v" , err )
272266 }
273267 }
274268}
@@ -310,7 +304,7 @@ func compressGzip(data []byte) ([]byte, error) {
310304 return buf .Bytes (), nil
311305}
312306
313- func (s * server ) handlePlainHTTPRequest (req * http.Request , w http.ResponseWriter , connectionCallbacks serverTypes.ConnectionCallbacks ) {
307+ func (s * server ) handlePlainHTTPRequest (req * http.Request , w http.ResponseWriter , connectionCallbacks * serverTypes.ConnectionCallbacks ) {
314308 bodyBytes , err := s .readReqBody (req )
315309 if err != nil {
316310 s .logger .Debugf (req .Context (), "Cannot read HTTP body: %v" , err )
@@ -331,11 +325,6 @@ func (s *server) handlePlainHTTPRequest(req *http.Request, w http.ResponseWriter
331325 conn : connFromRequest (req ),
332326 }
333327
334- if connectionCallbacks == nil {
335- w .WriteHeader (http .StatusInternalServerError )
336- return
337- }
338-
339328 connectionCallbacks .OnConnected (req .Context (), agentConn )
340329
341330 defer func () {
0 commit comments