@@ -21,13 +21,13 @@ import (
2121 "github.com/TwiN/gocache/v2"
2222 "github.com/TwiN/logr"
2323 "github.com/TwiN/whois"
24+ "github.com/gorilla/websocket"
2425 "github.com/ishidawataru/sctp"
2526 "github.com/miekg/dns"
2627 ping "github.com/prometheus-community/pro-bing"
2728 "github.com/registrobr/rdap"
2829 "github.com/registrobr/rdap/protocol"
2930 "golang.org/x/crypto/ssh"
30- "golang.org/x/net/websocket"
3131)
3232
3333const (
@@ -394,48 +394,53 @@ func ShouldRunPingerAsPrivileged() bool {
394394// QueryWebSocket opens a websocket connection, write `body` and return a message from the server
395395func QueryWebSocket (address , body string , headers map [string ]string , config * Config ) (bool , []byte , error ) {
396396 const (
397- Origin = "http://localhost/"
398- MaximumMessageSize = 1024 // in bytes
397+ Origin = "http://localhost/"
399398 )
400- wsConfig , err := websocket .NewConfig (address , Origin )
401- if err != nil {
402- return false , nil , fmt .Errorf ("error configuring websocket connection: %w" , err )
403- }
404- if headers != nil {
405- if wsConfig .Header == nil {
406- wsConfig .Header = make (http.Header )
407- }
408- for name , value := range headers {
409- wsConfig .Header .Set (name , value )
399+ var (
400+ dialer = websocket.Dialer {
401+ EnableCompression : true ,
410402 }
403+ wsHeaders = make (http.Header )
404+ )
405+
406+ wsHeaders .Set ("Origin" , Origin )
407+ for name , value := range headers {
408+ wsHeaders .Set (name , value )
411409 }
410+
411+ ctx := context .Background ()
412412 if config != nil {
413- wsConfig .Dialer = & net.Dialer {Timeout : config .Timeout }
414- wsConfig .TlsConfig = & tls.Config {
413+ if config .Timeout > 0 {
414+ var cancel context.CancelFunc
415+ ctx , cancel = context .WithTimeout (ctx , config .Timeout )
416+ defer cancel ()
417+ }
418+ dialer .TLSClientConfig = & tls.Config {
415419 InsecureSkipVerify : config .Insecure ,
416420 }
417421 if config .HasTLSConfig () && config .TLS .isValid () == nil {
418- wsConfig . TlsConfig = configureTLS (wsConfig . TlsConfig , * config .TLS )
422+ dialer . TLSClientConfig = configureTLS (dialer . TLSClientConfig , * config .TLS )
419423 }
420424 }
421425 // Dial URL
422- ws , err := websocket . DialConfig ( wsConfig )
426+ ws , _ , err := dialer . DialContext ( ctx , address , wsHeaders )
423427 if err != nil {
424428 return false , nil , fmt .Errorf ("error dialing websocket: %w" , err )
425429 }
426430 defer ws .Close ()
427431 body = parseLocalAddressPlaceholder (body , ws .LocalAddr ())
428432 // Write message
429- if _ , err := ws .Write ( []byte (body )); err != nil {
433+ if err := ws .WriteMessage ( websocket . TextMessage , []byte (body )); err != nil {
430434 return false , nil , fmt .Errorf ("error writing websocket body: %w" , err )
431435 }
432436 // Read message
433- var n int
434- msg := make ([]byte , MaximumMessageSize )
435- if n , err = ws .Read (msg ); err != nil {
437+ msgType , msg , err := ws .ReadMessage ()
438+ if err != nil {
436439 return false , nil , fmt .Errorf ("error reading websocket message: %w" , err )
440+ } else if msgType != websocket .TextMessage && msgType != websocket .BinaryMessage {
441+ return false , nil , fmt .Errorf ("unexpected websocket message type: %d, expected %d or %d" , msgType , websocket .TextMessage , websocket .BinaryMessage )
437442 }
438- return true , msg [: n ] , nil
443+ return true , msg , nil
439444}
440445
441446func QueryDNS (queryType , queryName , url string ) (connected bool , dnsRcode string , body []byte , err error ) {
0 commit comments