From c006b35871b2181518da467204547f21b5171f9e Mon Sep 17 00:00:00 2001 From: eleith Date: Thu, 11 Sep 2025 04:48:49 -0700 Subject: [PATCH] feat(client): starttls support for dns resolver (#1253) * customize starttls dialup connection if dnsresolver has a value, mirroring http client * add starttls connection test with a dns resolver --------- Co-authored-by: eleith --- client/client.go | 36 +++++++++++++++++++++++++++++++++--- client/client_test.go | 12 +++++++++++- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/client/client.go b/client/client.go index bc8518e4..6aea4718 100644 --- a/client/client.go +++ b/client/client.go @@ -1,6 +1,7 @@ package client import ( + "context" "crypto/tls" "crypto/x509" "encoding/json" @@ -141,10 +142,39 @@ func CanPerformStartTLS(address string, config *Config) (connected bool, certifi if len(hostAndPort) != 2 { return false, nil, errors.New("invalid address for starttls, format must be host:port") } - connection, err := net.DialTimeout("tcp", address, config.Timeout) - if err != nil { - return + + var connection net.Conn + var dnsResolver *DNSResolverConfig + + if config.HasCustomDNSResolver() { + dnsResolver, err = config.parseDNSResolver() + + if err != nil { + // We're ignoring the error, because it should have been validated on startup ValidateAndSetDefaults. + // It shouldn't happen, but if it does, we'll log it... Better safe than sorry ;) + logr.Errorf("[client.getHTTPClient] THIS SHOULD NOT HAPPEN. Silently ignoring invalid DNS resolver due to error: %s", err.Error()) + } else { + dialer := &net.Dialer{ + Resolver: &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{} + return d.DialContext(ctx, dnsResolver.Protocol, dnsResolver.Host+":"+dnsResolver.Port) + }, + }, + } + connection, err = dialer.DialContext(context.Background(), "tcp", address) + if err != nil { + return + } + } + } else { + connection, err = net.DialTimeout("tcp", address, config.Timeout) + if err != nil { + return + } } + smtpClient, err := smtp.NewClient(connection, hostAndPort[0]) if err != nil { return diff --git a/client/client_test.go b/client/client_test.go index 632a0ec2..7232147a 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -119,6 +119,7 @@ func TestCanPerformStartTLS(t *testing.T) { type args struct { address string insecure bool + dnsresolver string } tests := []struct { name string @@ -150,11 +151,20 @@ func TestCanPerformStartTLS(t *testing.T) { wantConnected: true, wantErr: false, }, + { + name: "dns resolver", + args: args{ + address: "smtp.gmail.com:587", + dnsresolver: "tcp://1.1.1.1:53", + }, + wantConnected: true, + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - connected, _, err := CanPerformStartTLS(tt.args.address, &Config{Insecure: tt.args.insecure, Timeout: 5 * time.Second}) + connected, _, err := CanPerformStartTLS(tt.args.address, &Config{Insecure: tt.args.insecure, Timeout: 5 * time.Second, DNSResolver: tt.args.dnsresolver}) if (err != nil) != tt.wantErr { t.Errorf("CanPerformStartTLS() err=%v, wantErr=%v", err, tt.wantErr) return