Add ssh_config parsing
This commit is contained in:
82
src/main.go
82
src/main.go
@ -10,9 +10,11 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/kevinburke/ssh_config"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
@ -60,12 +62,70 @@ func NewRouterBackup(hostname, username, password, keyFile string, port int) *Ro
|
||||
|
||||
// Connect establishes SSH connection to the router
|
||||
func (rb *RouterBackup) Connect() error {
|
||||
// Get SSH config values for this host
|
||||
hostname := ssh_config.Get(rb.hostname, "Hostname")
|
||||
if hostname == "" {
|
||||
hostname = rb.hostname
|
||||
}
|
||||
|
||||
portStr := ssh_config.Get(rb.hostname, "Port")
|
||||
port := rb.port
|
||||
if portStr != "" {
|
||||
if p, err := strconv.Atoi(portStr); err == nil {
|
||||
port = p
|
||||
}
|
||||
}
|
||||
|
||||
username := ssh_config.Get(rb.hostname, "User")
|
||||
if rb.username != "" {
|
||||
username = rb.username
|
||||
}
|
||||
|
||||
keyFile := ssh_config.Get(rb.hostname, "IdentityFile")
|
||||
if rb.keyFile != "" {
|
||||
keyFile = rb.keyFile
|
||||
}
|
||||
|
||||
config := &ssh.ClientConfig{
|
||||
User: rb.username,
|
||||
User: username,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Apply SSH config crypto settings with compatibility filtering
|
||||
if kexAlgorithms := ssh_config.Get(rb.hostname, "KexAlgorithms"); kexAlgorithms != "" && !strings.HasPrefix(kexAlgorithms, "+") {
|
||||
// Only apply if it's an explicit list, not a +append
|
||||
algorithms := strings.Split(kexAlgorithms, ",")
|
||||
var finalAlgorithms []string
|
||||
for _, alg := range algorithms {
|
||||
finalAlgorithms = append(finalAlgorithms, strings.TrimSpace(alg))
|
||||
}
|
||||
config.KeyExchanges = finalAlgorithms
|
||||
}
|
||||
|
||||
// Note: Cipher overrides disabled - Go SSH library defaults work
|
||||
// if ciphers := ssh_config.Get(rb.hostname, "Ciphers"); ciphers != "" {
|
||||
// config.Ciphers = ...
|
||||
// }
|
||||
|
||||
if macs := ssh_config.Get(rb.hostname, "MACs"); macs != "" {
|
||||
macList := strings.Split(macs, ",")
|
||||
for i, mac := range macList {
|
||||
macList[i] = strings.TrimSpace(mac)
|
||||
}
|
||||
config.MACs = macList
|
||||
}
|
||||
|
||||
if hostKeyAlgorithms := ssh_config.Get(rb.hostname, "HostKeyAlgorithms"); hostKeyAlgorithms != "" && !strings.HasPrefix(hostKeyAlgorithms, "+") {
|
||||
// Only apply if it's an explicit list, not a +append
|
||||
algorithms := strings.Split(hostKeyAlgorithms, ",")
|
||||
var finalAlgorithms []string
|
||||
for _, alg := range algorithms {
|
||||
finalAlgorithms = append(finalAlgorithms, strings.TrimSpace(alg))
|
||||
}
|
||||
config.HostKeyAlgorithms = finalAlgorithms
|
||||
}
|
||||
|
||||
// Try SSH agent first if available
|
||||
if sshAuthSock := os.Getenv("SSH_AUTH_SOCK"); sshAuthSock != "" {
|
||||
if conn, err := net.Dial("unix", sshAuthSock); err == nil {
|
||||
@ -75,8 +135,16 @@ func (rb *RouterBackup) Connect() error {
|
||||
}
|
||||
|
||||
// If SSH agent didn't work, try key file
|
||||
if len(config.Auth) == 0 && rb.keyFile != "" {
|
||||
key, err := ioutil.ReadFile(rb.keyFile)
|
||||
if len(config.Auth) == 0 && keyFile != "" {
|
||||
// Expand ~ in keyFile path
|
||||
if strings.HasPrefix(keyFile, "~/") {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err == nil {
|
||||
keyFile = filepath.Join(homeDir, keyFile[2:])
|
||||
}
|
||||
}
|
||||
|
||||
key, err := ioutil.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read private key: %v", err)
|
||||
}
|
||||
@ -98,14 +166,14 @@ func (rb *RouterBackup) Connect() error {
|
||||
return fmt.Errorf("no authentication method available")
|
||||
}
|
||||
|
||||
address := fmt.Sprintf("%s:%d", rb.hostname, rb.port)
|
||||
client, err := ssh.Dial("tcp", address, config)
|
||||
address := fmt.Sprintf("%s:%d", hostname, port)
|
||||
client, err := ssh.Dial("tcp4", address, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to %s: %v", rb.hostname, err)
|
||||
return fmt.Errorf("failed to connect to %s: %v", hostname, err)
|
||||
}
|
||||
|
||||
rb.client = client
|
||||
fmt.Printf("Successfully connected to %s\n", rb.hostname)
|
||||
fmt.Printf("Successfully connected to %s\n", hostname)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user