Files
gatus/config/tunneling/sshtunnel/sshtunnel.go
TwiN 40345a03d3 feat(client): Add support for SSH tunneling (#1298)
* feat(client): Add support for SSH tunneling

* Fix test
2025-09-28 14:26:12 -04:00

158 lines
4.1 KiB
Go

package sshtunnel
import (
	"fmt"
	"net"
	"strconv"
	"strings"
	"sync"
	"time"

	"golang.org/x/crypto/ssh"
)
// Config represents the configuration for an SSH tunnel.
type Config struct {
	// Type is the tunnel type; only "SSH" (exact case) is accepted.
	Type string `yaml:"type"`
	// Host is the hostname or IP address of the SSH server.
	Host string `yaml:"host"`
	// Port is the SSH server port. A zero value is replaced with 22 by
	// ValidateAndSetDefaults.
	Port int `yaml:"port,omitempty"`
	// Username is the user to authenticate as on the SSH server.
	Username string `yaml:"username"`
	// PrivateKey holds the private key contents themselves (not a file path).
	// When set, it is used instead of Password.
	PrivateKey string `yaml:"private-key,omitempty"`
	// Password is used for password authentication when PrivateKey is empty.
	Password string `yaml:"password,omitempty"`
}

// ValidateAndSetDefaults validates the SSH tunnel configuration and sets defaults.
//
// It returns an error when Type is not "SSH", when Host or Username is empty,
// when neither PrivateKey nor Password is provided, or when Port is not a
// valid TCP port (1-65535) after defaulting. A zero Port is set to 22.
func (c *Config) ValidateAndSetDefaults() error {
	if c.Type != "SSH" {
		return fmt.Errorf("unsupported tunnel type: %s", c.Type)
	}
	if c.Host == "" {
		return fmt.Errorf("host is required")
	}
	if c.Username == "" {
		return fmt.Errorf("username is required")
	}
	if c.PrivateKey == "" && c.Password == "" {
		return fmt.Errorf("either private-key or password is required")
	}
	if c.Port == 0 {
		c.Port = 22
	}
	// Reject values that can never be a TCP port here, at configuration time,
	// rather than letting them surface later as an opaque dial failure.
	if c.Port < 1 || c.Port > 65535 {
		return fmt.Errorf("invalid port %d: must be between 1 and 65535", c.Port)
	}
	return nil
}
// SSHTunnel represents an SSH tunnel connection.
//
// It holds at most one *ssh.Client to the configured SSH server and uses it to
// open connections to other addresses through the tunnel (see Dial). The
// client field is only accessed while holding mu; config and authMethods are
// populated by New and only read afterwards.
type SSHTunnel struct {
	// config is the tunnel configuration this tunnel was created from.
	config *Config
	// mu guards client.
	mu sync.RWMutex
	// client is the live SSH connection, or nil when not connected.
	client *ssh.Client
	// authMethods caches the authentication methods built once by New, so
	// reconnects do not have to re-parse the private key. It is empty when
	// no usable credential could be built (e.g. an unparsable private key).
	authMethods []ssh.AuthMethod
}
// New builds an SSHTunnel for config without contacting the server.
//
// Credentials are turned into ssh.AuthMethod values once, here, so that later
// (re)connections skip the costly key parsing. A PrivateKey takes precedence
// over a Password. Because New has no error return, a PrivateKey that fails to
// parse is not reported here: the tunnel is simply left without auth methods,
// and the first connection attempt fails instead.
func New(config *Config) *SSHTunnel {
	st := &SSHTunnel{config: config}
	switch {
	case config.PrivateKey != "":
		signer, parseErr := ssh.ParsePrivateKey([]byte(config.PrivateKey))
		if parseErr != nil {
			// Intentionally not propagated to keep New's signature backward
			// compatible; see the doc comment above.
			break
		}
		st.authMethods = []ssh.AuthMethod{ssh.PublicKeys(signer)}
	case config.Password != "":
		st.authMethods = []ssh.AuthMethod{ssh.Password(config.Password)}
	}
	return st
}
// Connect dials the SSH server described by the tunnel's Config and stores the
// resulting client. The write lock is held for the entire dial, so concurrent
// Dial and Close calls block until it completes.
func (t *SSHTunnel) Connect() error {
	t.mu.Lock()
	defer t.mu.Unlock()
	if err := t.connectUnsafe(); err != nil {
		return err
	}
	return nil
}
// connectUnsafe dials the configured SSH server and stores the new client in
// t.client. It does not take any lock: the caller must already hold t.mu for
// writing.
//
// If a client was already stored, it is closed only after the new connection
// succeeds, so repeated calls (e.g. calling Connect twice) do not leak SSH
// connections. On failure, t.client is left unchanged.
func (t *SSHTunnel) connectUnsafe() error {
	if len(t.authMethods) == 0 {
		// New silently drops private-key parse errors. Re-parse here, on the
		// failure path only, so the caller learns why there is no credential
		// instead of getting a generic message.
		if t.config.PrivateKey != "" {
			if _, err := ssh.ParsePrivateKey([]byte(t.config.PrivateKey)); err != nil {
				return fmt.Errorf("invalid private key: %w", err)
			}
		}
		return fmt.Errorf("no authentication method available")
	}
	clientConfig := &ssh.ClientConfig{
		User:    t.config.Username,
		Auth:    t.authMethods, // pre-built by New
		Timeout: 30 * time.Second,
		// SECURITY(review): the server's host key is not verified, which leaves
		// the tunnel open to man-in-the-middle attacks. Adding a known-hosts or
		// fingerprint option to Config would allow ssh.FixedHostKey or
		// knownhosts.New to be used here instead.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}
	// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:22"), which a plain
	// "%s:%d" format does not. Strip brackets the user may already have written
	// so "[::1]" keeps working too.
	host := strings.TrimSuffix(strings.TrimPrefix(t.config.Host, "["), "]")
	addr := net.JoinHostPort(host, strconv.Itoa(t.config.Port))
	client, err := ssh.Dial("tcp", addr, clientConfig)
	if err != nil {
		return fmt.Errorf("SSH connection failed: %w", err)
	}
	if t.client != nil {
		// Best-effort: the old connection is being replaced, and a close error
		// here must not mask the successful reconnect.
		_ = t.client.Close()
	}
	t.client = client
	return nil
}
// Close shuts down the current SSH connection, if any, and forgets it, so that
// a later Dial establishes a fresh one. Closing a tunnel that is not connected
// is a no-op returning nil; otherwise the client's Close error is returned.
func (t *SSHTunnel) Close() error {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.client == nil {
		return nil
	}
	closeErr := t.client.Close()
	t.client = nil
	return closeErr
}
// Dial creates a connection to addr (on the given network) through the SSH
// tunnel, connecting to the SSH server first if necessary.
//
// If dialing through the current SSH client fails, the client is assumed to be
// stale: it is closed, a new SSH connection is established, and the dial is
// retried exactly once. Errors from the initial connection attempt are
// returned as-is; errors from the dial/retry path are wrapped.
func (t *SSHTunnel) Dial(network, addr string) (net.Conn, error) {
	client, err := t.currentOrNewClient()
	if err != nil {
		return nil, err
	}
	conn, err := client.Dial(network, addr)
	if err == nil {
		return conn, nil
	}
	// The connection might be stale: replace it and retry once.
	fresh, connErr := t.replaceStaleClient(client)
	if connErr != nil {
		return nil, fmt.Errorf("SSH tunnel dial failed: %w (retry failed: %v)", err, connErr)
	}
	conn, err = fresh.Dial(network, addr)
	if err != nil {
		return nil, fmt.Errorf("SSH tunnel dial failed after retry: %w", err)
	}
	return conn, nil
}

// currentOrNewClient returns the stored SSH client, connecting first when
// there is none. The returned client is never nil when err is nil.
func (t *SSHTunnel) currentOrNewClient() (*ssh.Client, error) {
	t.mu.RLock()
	client := t.client
	t.mu.RUnlock()
	if client != nil {
		return client, nil
	}
	t.mu.Lock()
	defer t.mu.Unlock()
	// Another goroutine may have connected while we waited for the write lock.
	if t.client == nil {
		if err := t.connectUnsafe(); err != nil {
			return nil, err
		}
	}
	return t.client, nil
}

// replaceStaleClient closes stale if it is still the stored client and returns
// a connected client, reconnecting when needed. The check-and-replace happens
// under a single write lock, so concurrent failing Dials neither close a client
// that another goroutine just established nor leak one by reconnecting twice.
// The returned client is never nil when err is nil.
func (t *SSHTunnel) replaceStaleClient(stale *ssh.Client) (*ssh.Client, error) {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.client != nil && t.client == stale {
		// Close the stale connection before reconnecting to prevent a leak;
		// the close error is irrelevant since the client is being discarded.
		_ = t.client.Close()
		t.client = nil
	}
	if t.client == nil {
		if err := t.connectUnsafe(); err != nil {
			return nil, err
		}
	}
	return t.client, nil
}