From 15a805561747f27e99106666e081e5eea73fd357 Mon Sep 17 00:00:00 2001 From: Yaroslav Date: Thu, 18 Dec 2025 23:44:44 +0000 Subject: [PATCH] fix(client): Switch websocket library (#1423) * fix(websocket): switch to gorilla/websocket * fix(client): add missing t.Parallel() in tests --------- Co-authored-by: TwiN --- client/client.go | 49 ++++++++++++++++++++++++------------------- client/client_test.go | 8 +++++++ go.mod | 3 ++- go.sum | 2 ++ 4 files changed, 39 insertions(+), 23 deletions(-) diff --git a/client/client.go b/client/client.go index b7abf175..825cab42 100644 --- a/client/client.go +++ b/client/client.go @@ -21,13 +21,13 @@ import ( "github.com/TwiN/gocache/v2" "github.com/TwiN/logr" "github.com/TwiN/whois" + "github.com/gorilla/websocket" "github.com/ishidawataru/sctp" "github.com/miekg/dns" ping "github.com/prometheus-community/pro-bing" "github.com/registrobr/rdap" "github.com/registrobr/rdap/protocol" "golang.org/x/crypto/ssh" - "golang.org/x/net/websocket" ) const ( @@ -394,48 +394,53 @@ func ShouldRunPingerAsPrivileged() bool { // QueryWebSocket opens a websocket connection, write `body` and return a message from the server func QueryWebSocket(address, body string, headers map[string]string, config *Config) (bool, []byte, error) { const ( - Origin = "http://localhost/" - MaximumMessageSize = 1024 // in bytes + Origin = "http://localhost/" ) - wsConfig, err := websocket.NewConfig(address, Origin) - if err != nil { - return false, nil, fmt.Errorf("error configuring websocket connection: %w", err) - } - if headers != nil { - if wsConfig.Header == nil { - wsConfig.Header = make(http.Header) - } - for name, value := range headers { - wsConfig.Header.Set(name, value) + var ( + dialer = websocket.Dialer{ + EnableCompression: true, } + wsHeaders = make(http.Header) + ) + + wsHeaders.Set("Origin", Origin) + for name, value := range headers { + wsHeaders.Set(name, value) } + + ctx := context.Background() if config != nil { - wsConfig.Dialer = &net.Dialer{Timeout: config.Timeout} - wsConfig.TlsConfig = &tls.Config{ + if config.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, config.Timeout) + defer cancel() + } + dialer.TLSClientConfig = &tls.Config{ InsecureSkipVerify: config.Insecure, } if config.HasTLSConfig() && config.TLS.isValid() == nil { - wsConfig.TlsConfig = configureTLS(wsConfig.TlsConfig, *config.TLS) + dialer.TLSClientConfig = configureTLS(dialer.TLSClientConfig, *config.TLS) } } // Dial URL - ws, err := websocket.DialConfig(wsConfig) + ws, _, err := dialer.DialContext(ctx, address, wsHeaders) if err != nil { return false, nil, fmt.Errorf("error dialing websocket: %w", err) } defer ws.Close() body = parseLocalAddressPlaceholder(body, ws.LocalAddr()) // Write message - if _, err := ws.Write([]byte(body)); err != nil { + if err := ws.WriteMessage(websocket.TextMessage, []byte(body)); err != nil { return false, nil, fmt.Errorf("error writing websocket body: %w", err) } // Read message - var n int - msg := make([]byte, MaximumMessageSize) - if n, err = ws.Read(msg); err != nil { + msgType, msg, err := ws.ReadMessage() + if err != nil { return false, nil, fmt.Errorf("error reading websocket message: %w", err) + } else if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage { + return false, nil, fmt.Errorf("unexpected websocket message type: %d, expected %d or %d", msgType, websocket.TextMessage, websocket.BinaryMessage) } - return true, msg[:n], nil + return true, msg, nil } func QueryDNS(queryType, queryName, url string) (connected bool, dnsRcode string, body []byte, err error) { diff --git a/client/client_test.go b/client/client_test.go index 447655a7..aac17183 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -17,6 +17,7 @@ import ( ) func TestGetHTTPClient(t *testing.T) { + t.Parallel() cfg := &Config{ Insecure: false, IgnoreRedirect: false, @@ -42,6 +43,7 @@ func TestGetHTTPClient(t *testing.T) { } func TestRdapQuery(t *testing.T) { + t.Parallel() if _, err := rdapQuery("1.1.1.1"); err == nil { t.Error("expected an error due to the invalid domain type") } @@ -288,6 +290,7 @@ func TestCanPerformTLS(t *testing.T) { } func TestCanCreateConnection(t *testing.T) { + t.Parallel() connected, _ := CanCreateNetworkConnection("tcp", "127.0.0.1", "", &Config{Timeout: 5 * time.Second}) if connected { t.Error("should've failed, because there's no port in the address") @@ -302,6 +305,7 @@ func TestCanCreateConnection(t *testing.T) { // performs a Client Credentials OAuth2 flow and adds the obtained token as a `Authorization` // header to all outgoing HTTP calls. func TestHttpClientProvidesOAuth2BearerToken(t *testing.T) { + t.Parallel() defer InjectHTTPClient(nil) oAuth2Config := &OAuth2Config{ ClientID: "00000000-0000-0000-0000-000000000000", @@ -357,6 +361,7 @@ func TestHttpClientProvidesOAuth2BearerToken(t *testing.T) { } func TestQueryWebSocket(t *testing.T) { + t.Parallel() _, _, err := QueryWebSocket("", "body", nil, &Config{Timeout: 2 * time.Second}) if err == nil { t.Error("expected an error due to the address being invalid") @@ -368,6 +373,7 @@ func TestQueryWebSocket(t *testing.T) { } func TestTlsRenegotiation(t *testing.T) { + t.Parallel() scenarios := []struct { name string cfg TLSConfig @@ -411,6 +417,7 @@ func TestTlsRenegotiation(t *testing.T) { } func TestQueryDNS(t *testing.T) { + t.Parallel() scenarios := []struct { name string inputDNS dns.Config @@ -540,6 +547,7 @@ func TestQueryDNS(t *testing.T) { } func TestCheckSSHBanner(t *testing.T) { + t.Parallel() cfg := &Config{Timeout: 3} t.Run("no-auth-ssh", func(t *testing.T) { connected, status, err := CheckSSHBanner("tty.sdf.org", cfg) diff --git a/go.mod b/go.mod index 11767c03..cecb36f7 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/gofiber/fiber/v2 v2.52.9 github.com/google/go-github/v48 v48.2.0 github.com/google/uuid v1.6.0 + github.com/gorilla/websocket v1.5.3 github.com/ishidawataru/sctp v0.0.0-20230406120618-7ff4192f6ff2 github.com/lib/pq v1.10.9 github.com/miekg/dns v1.1.68 @@ -29,7 +30,6 @@ require ( github.com/valyala/fasthttp v1.67.0 github.com/wcharczuk/go-chart/v2 v2.1.2 golang.org/x/crypto v0.45.0 - golang.org/x/net v0.47.0 golang.org/x/oauth2 v0.32.0 golang.org/x/sync v0.18.0 google.golang.org/api v0.252.0 @@ -93,6 +93,7 @@ require ( golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.29.0 // indirect + golang.org/x/net v0.47.0 // indirect golang.org/x/sys v0.38.0 // indirect golang.org/x/text v0.31.0 // indirect golang.org/x/tools v0.38.0 // indirect diff --git a/go.sum b/go.sum index af08c1d3..dbe49d0d 100644 --- a/go.sum +++ b/go.sum @@ -101,6 +101,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo= github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY= github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/ishidawataru/sctp v0.0.0-20230406120618-7ff4192f6ff2 h1:i2fYnDurfLlJH8AyyMOnkLHnHeP8Ff/DDpuZA/D3bPo=