fix(client): Switch websocket library (#1423)
* fix(websocket): switch to gorilla/websocket * fix(client): add missing t.Parallel() in tests --------- Co-authored-by: TwiN <twin@linux.com>
This commit is contained in:
@@ -21,13 +21,13 @@ import (
|
|||||||
"github.com/TwiN/gocache/v2"
|
"github.com/TwiN/gocache/v2"
|
||||||
"github.com/TwiN/logr"
|
"github.com/TwiN/logr"
|
||||||
"github.com/TwiN/whois"
|
"github.com/TwiN/whois"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
"github.com/ishidawataru/sctp"
|
"github.com/ishidawataru/sctp"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
ping "github.com/prometheus-community/pro-bing"
|
ping "github.com/prometheus-community/pro-bing"
|
||||||
"github.com/registrobr/rdap"
|
"github.com/registrobr/rdap"
|
||||||
"github.com/registrobr/rdap/protocol"
|
"github.com/registrobr/rdap/protocol"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
"golang.org/x/net/websocket"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -394,48 +394,53 @@ func ShouldRunPingerAsPrivileged() bool {
|
|||||||
// QueryWebSocket opens a websocket connection, write `body` and return a message from the server
|
// QueryWebSocket opens a websocket connection, write `body` and return a message from the server
|
||||||
func QueryWebSocket(address, body string, headers map[string]string, config *Config) (bool, []byte, error) {
|
func QueryWebSocket(address, body string, headers map[string]string, config *Config) (bool, []byte, error) {
|
||||||
const (
|
const (
|
||||||
Origin = "http://localhost/"
|
Origin = "http://localhost/"
|
||||||
MaximumMessageSize = 1024 // in bytes
|
|
||||||
)
|
)
|
||||||
wsConfig, err := websocket.NewConfig(address, Origin)
|
var (
|
||||||
if err != nil {
|
dialer = websocket.Dialer{
|
||||||
return false, nil, fmt.Errorf("error configuring websocket connection: %w", err)
|
EnableCompression: true,
|
||||||
}
|
|
||||||
if headers != nil {
|
|
||||||
if wsConfig.Header == nil {
|
|
||||||
wsConfig.Header = make(http.Header)
|
|
||||||
}
|
|
||||||
for name, value := range headers {
|
|
||||||
wsConfig.Header.Set(name, value)
|
|
||||||
}
|
}
|
||||||
|
wsHeaders = make(http.Header)
|
||||||
|
)
|
||||||
|
|
||||||
|
wsHeaders.Set("Origin", Origin)
|
||||||
|
for name, value := range headers {
|
||||||
|
wsHeaders.Set(name, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
if config != nil {
|
if config != nil {
|
||||||
wsConfig.Dialer = &net.Dialer{Timeout: config.Timeout}
|
if config.Timeout > 0 {
|
||||||
wsConfig.TlsConfig = &tls.Config{
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithTimeout(ctx, config.Timeout)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
dialer.TLSClientConfig = &tls.Config{
|
||||||
InsecureSkipVerify: config.Insecure,
|
InsecureSkipVerify: config.Insecure,
|
||||||
}
|
}
|
||||||
if config.HasTLSConfig() && config.TLS.isValid() == nil {
|
if config.HasTLSConfig() && config.TLS.isValid() == nil {
|
||||||
wsConfig.TlsConfig = configureTLS(wsConfig.TlsConfig, *config.TLS)
|
dialer.TLSClientConfig = configureTLS(dialer.TLSClientConfig, *config.TLS)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Dial URL
|
// Dial URL
|
||||||
ws, err := websocket.DialConfig(wsConfig)
|
ws, _, err := dialer.DialContext(ctx, address, wsHeaders)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, nil, fmt.Errorf("error dialing websocket: %w", err)
|
return false, nil, fmt.Errorf("error dialing websocket: %w", err)
|
||||||
}
|
}
|
||||||
defer ws.Close()
|
defer ws.Close()
|
||||||
body = parseLocalAddressPlaceholder(body, ws.LocalAddr())
|
body = parseLocalAddressPlaceholder(body, ws.LocalAddr())
|
||||||
// Write message
|
// Write message
|
||||||
if _, err := ws.Write([]byte(body)); err != nil {
|
if err := ws.WriteMessage(websocket.TextMessage, []byte(body)); err != nil {
|
||||||
return false, nil, fmt.Errorf("error writing websocket body: %w", err)
|
return false, nil, fmt.Errorf("error writing websocket body: %w", err)
|
||||||
}
|
}
|
||||||
// Read message
|
// Read message
|
||||||
var n int
|
msgType, msg, err := ws.ReadMessage()
|
||||||
msg := make([]byte, MaximumMessageSize)
|
if err != nil {
|
||||||
if n, err = ws.Read(msg); err != nil {
|
|
||||||
return false, nil, fmt.Errorf("error reading websocket message: %w", err)
|
return false, nil, fmt.Errorf("error reading websocket message: %w", err)
|
||||||
|
} else if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
|
||||||
|
return false, nil, fmt.Errorf("unexpected websocket message type: %d, expected %d or %d", msgType, websocket.TextMessage, websocket.BinaryMessage)
|
||||||
}
|
}
|
||||||
return true, msg[:n], nil
|
return true, msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func QueryDNS(queryType, queryName, url string) (connected bool, dnsRcode string, body []byte, err error) {
|
func QueryDNS(queryType, queryName, url string) (connected bool, dnsRcode string, body []byte, err error) {
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestGetHTTPClient(t *testing.T) {
|
func TestGetHTTPClient(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
cfg := &Config{
|
cfg := &Config{
|
||||||
Insecure: false,
|
Insecure: false,
|
||||||
IgnoreRedirect: false,
|
IgnoreRedirect: false,
|
||||||
@@ -42,6 +43,7 @@ func TestGetHTTPClient(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRdapQuery(t *testing.T) {
|
func TestRdapQuery(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
if _, err := rdapQuery("1.1.1.1"); err == nil {
|
if _, err := rdapQuery("1.1.1.1"); err == nil {
|
||||||
t.Error("expected an error due to the invalid domain type")
|
t.Error("expected an error due to the invalid domain type")
|
||||||
}
|
}
|
||||||
@@ -288,6 +290,7 @@ func TestCanPerformTLS(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCanCreateConnection(t *testing.T) {
|
func TestCanCreateConnection(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
connected, _ := CanCreateNetworkConnection("tcp", "127.0.0.1", "", &Config{Timeout: 5 * time.Second})
|
connected, _ := CanCreateNetworkConnection("tcp", "127.0.0.1", "", &Config{Timeout: 5 * time.Second})
|
||||||
if connected {
|
if connected {
|
||||||
t.Error("should've failed, because there's no port in the address")
|
t.Error("should've failed, because there's no port in the address")
|
||||||
@@ -302,6 +305,7 @@ func TestCanCreateConnection(t *testing.T) {
|
|||||||
// performs a Client Credentials OAuth2 flow and adds the obtained token as a `Authorization`
|
// performs a Client Credentials OAuth2 flow and adds the obtained token as a `Authorization`
|
||||||
// header to all outgoing HTTP calls.
|
// header to all outgoing HTTP calls.
|
||||||
func TestHttpClientProvidesOAuth2BearerToken(t *testing.T) {
|
func TestHttpClientProvidesOAuth2BearerToken(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
defer InjectHTTPClient(nil)
|
defer InjectHTTPClient(nil)
|
||||||
oAuth2Config := &OAuth2Config{
|
oAuth2Config := &OAuth2Config{
|
||||||
ClientID: "00000000-0000-0000-0000-000000000000",
|
ClientID: "00000000-0000-0000-0000-000000000000",
|
||||||
@@ -357,6 +361,7 @@ func TestHttpClientProvidesOAuth2BearerToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryWebSocket(t *testing.T) {
|
func TestQueryWebSocket(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
_, _, err := QueryWebSocket("", "body", nil, &Config{Timeout: 2 * time.Second})
|
_, _, err := QueryWebSocket("", "body", nil, &Config{Timeout: 2 * time.Second})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected an error due to the address being invalid")
|
t.Error("expected an error due to the address being invalid")
|
||||||
@@ -368,6 +373,7 @@ func TestQueryWebSocket(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTlsRenegotiation(t *testing.T) {
|
func TestTlsRenegotiation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
scenarios := []struct {
|
scenarios := []struct {
|
||||||
name string
|
name string
|
||||||
cfg TLSConfig
|
cfg TLSConfig
|
||||||
@@ -411,6 +417,7 @@ func TestTlsRenegotiation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryDNS(t *testing.T) {
|
func TestQueryDNS(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
scenarios := []struct {
|
scenarios := []struct {
|
||||||
name string
|
name string
|
||||||
inputDNS dns.Config
|
inputDNS dns.Config
|
||||||
@@ -540,6 +547,7 @@ func TestQueryDNS(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCheckSSHBanner(t *testing.T) {
|
func TestCheckSSHBanner(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
cfg := &Config{Timeout: 3}
|
cfg := &Config{Timeout: 3}
|
||||||
t.Run("no-auth-ssh", func(t *testing.T) {
|
t.Run("no-auth-ssh", func(t *testing.T) {
|
||||||
connected, status, err := CheckSSHBanner("tty.sdf.org", cfg)
|
connected, status, err := CheckSSHBanner("tty.sdf.org", cfg)
|
||||||
|
|||||||
3
go.mod
3
go.mod
@@ -20,6 +20,7 @@ require (
|
|||||||
github.com/gofiber/fiber/v2 v2.52.9
|
github.com/gofiber/fiber/v2 v2.52.9
|
||||||
github.com/google/go-github/v48 v48.2.0
|
github.com/google/go-github/v48 v48.2.0
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
|
github.com/gorilla/websocket v1.5.3
|
||||||
github.com/ishidawataru/sctp v0.0.0-20230406120618-7ff4192f6ff2
|
github.com/ishidawataru/sctp v0.0.0-20230406120618-7ff4192f6ff2
|
||||||
github.com/lib/pq v1.10.9
|
github.com/lib/pq v1.10.9
|
||||||
github.com/miekg/dns v1.1.68
|
github.com/miekg/dns v1.1.68
|
||||||
@@ -29,7 +30,6 @@ require (
|
|||||||
github.com/valyala/fasthttp v1.67.0
|
github.com/valyala/fasthttp v1.67.0
|
||||||
github.com/wcharczuk/go-chart/v2 v2.1.2
|
github.com/wcharczuk/go-chart/v2 v2.1.2
|
||||||
golang.org/x/crypto v0.45.0
|
golang.org/x/crypto v0.45.0
|
||||||
golang.org/x/net v0.47.0
|
|
||||||
golang.org/x/oauth2 v0.32.0
|
golang.org/x/oauth2 v0.32.0
|
||||||
golang.org/x/sync v0.18.0
|
golang.org/x/sync v0.18.0
|
||||||
google.golang.org/api v0.252.0
|
google.golang.org/api v0.252.0
|
||||||
@@ -93,6 +93,7 @@ require (
|
|||||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect
|
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect
|
||||||
golang.org/x/image v0.18.0 // indirect
|
golang.org/x/image v0.18.0 // indirect
|
||||||
golang.org/x/mod v0.29.0 // indirect
|
golang.org/x/mod v0.29.0 // indirect
|
||||||
|
golang.org/x/net v0.47.0 // indirect
|
||||||
golang.org/x/sys v0.38.0 // indirect
|
golang.org/x/sys v0.38.0 // indirect
|
||||||
golang.org/x/text v0.31.0 // indirect
|
golang.org/x/text v0.31.0 // indirect
|
||||||
golang.org/x/tools v0.38.0 // indirect
|
golang.org/x/tools v0.38.0 // indirect
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -101,6 +101,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU
|
|||||||
github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA=
|
github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA=
|
||||||
github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo=
|
github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo=
|
||||||
github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc=
|
github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc=
|
||||||
|
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||||
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY=
|
github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY=
|
||||||
github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
|
github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
|
||||||
github.com/ishidawataru/sctp v0.0.0-20230406120618-7ff4192f6ff2 h1:i2fYnDurfLlJH8AyyMOnkLHnHeP8Ff/DDpuZA/D3bPo=
|
github.com/ishidawataru/sctp v0.0.0-20230406120618-7ff4192f6ff2 h1:i2fYnDurfLlJH8AyyMOnkLHnHeP8Ff/DDpuZA/D3bPo=
|
||||||
|
|||||||
Reference in New Issue
Block a user