From 405c15f756242bffce63345962f640398749e377 Mon Sep 17 00:00:00 2001 From: TwiN Date: Tue, 30 Sep 2025 14:08:56 -0400 Subject: [PATCH] fix(tunneling): Add exponential backoff retry (#1303) --- config/tunneling/sshtunnel/sshtunnel.go | 49 ++++++++++++++----------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/config/tunneling/sshtunnel/sshtunnel.go b/config/tunneling/sshtunnel/sshtunnel.go index 1063585f..e0f0d258 100644 --- a/config/tunneling/sshtunnel/sshtunnel.go +++ b/config/tunneling/sshtunnel/sshtunnel.go @@ -54,7 +54,6 @@ func New(config *Config) *SSHTunnel { tunnel := &SSHTunnel{ config: config, } - // Parse authentication methods once during initialization to avoid // expensive cryptographic operations on every connection attempt if config.PrivateKey != "" { @@ -66,7 +65,6 @@ func New(config *Config) *SSHTunnel { } else if config.Password != "" { tunnel.authMethods = []ssh.AuthMethod{ssh.Password(config.Password)} } - return tunnel } @@ -131,27 +129,34 @@ func (t *SSHTunnel) Dial(network, addr string) (net.Conn, error) { client = t.client t.mu.Unlock() } - // Create connection through SSH tunnel - conn, err := client.Dial(network, addr) - if err != nil { - // Close stale connection before retry to prevent leak - t.mu.Lock() - if t.client != nil { - t.client.Close() - t.client = nil + // Attempt dial with exponential backoff retry + const maxRetries = 3 + const baseDelay = time.Second + var lastErr error + for attempt := 0; attempt < maxRetries; attempt++ { + if attempt > 0 { + // Exponential backoff before each retry: 1s, then 2s (maxRetries = 3 means at most two retries) + delay := baseDelay << (attempt - 1) + time.Sleep(delay) + // Close stale connection and reconnect + t.mu.Lock() + if t.client != nil { + _ = t.client.Close() + t.client = nil + } + if err := t.connectUnsafe(); err != nil { + t.mu.Unlock() + lastErr = fmt.Errorf("reconnect attempt %d failed: %w", attempt, err) + continue + } + client = t.client + t.mu.Unlock() } - t.mu.Unlock() - // Retry once - connection might be stale - if 
connErr := t.Connect(); connErr != nil { - return nil, fmt.Errorf("SSH tunnel dial failed: %w (retry failed: %v)", err, connErr) - } - t.mu.RLock() - client = t.client - t.mu.RUnlock() - conn, err = client.Dial(network, addr) - if err != nil { - return nil, fmt.Errorf("SSH tunnel dial failed after retry: %w", err) + conn, err := client.Dial(network, addr) + if err == nil { + return conn, nil } + lastErr = err } - return conn, nil + return nil, fmt.Errorf("SSH tunnel dial failed after %d attempts: %w", maxRetries, lastErr) }