326 lines
8.5 KiB
Go
326 lines
8.5 KiB
Go
// Copyright 2025, IPng Networks GmbH, Pim van Pelt <pim@ipng.ch>
|
|
|
|
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/kevinburke/ssh_config"
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/crypto/ssh/agent"
|
|
)
|
|
|
|
// RouterBackup handles SSH connections and command execution
|
|
type RouterBackup struct {
|
|
hostname string
|
|
address string
|
|
username string
|
|
password string
|
|
keyFile string
|
|
port int
|
|
client *ssh.Client
|
|
}
|
|
|
|
// NewRouterBackup creates a new RouterBackup instance
|
|
func NewRouterBackup(hostname, address, username, password, keyFile string, port int) *RouterBackup {
|
|
return &RouterBackup{
|
|
hostname: hostname,
|
|
address: address,
|
|
username: username,
|
|
password: password,
|
|
keyFile: keyFile,
|
|
port: port,
|
|
}
|
|
}
|
|
|
|
// isIPv6 reports whether address is an IPv6 literal.
//
// Zone-scoped literals such as "fe80::1%eth0" are accepted: net.ParseIP
// rejects the "%zone" suffix, so it is stripped before parsing. Hostnames,
// IPv4 literals and IPv4-mapped IPv6 literals (whose To4 form is non-nil)
// all yield false.
func isIPv6(address string) bool {
	if i := strings.IndexByte(address, '%'); i >= 0 {
		address = address[:i]
	}
	ip := net.ParseIP(address)
	return ip != nil && ip.To4() == nil
}
|
|
|
|
// getNetworkType determines the appropriate network type based on the target address
|
|
func getNetworkType(address string) string {
|
|
if isIPv6(address) {
|
|
return "tcp6"
|
|
}
|
|
return "tcp4"
|
|
}
|
|
|
|
// Connect establishes SSH connection to the router
|
|
func (rb *RouterBackup) Connect() error {
|
|
// Determine the target address - use explicit address if provided, otherwise use hostname
|
|
var targetHost string
|
|
if rb.address != "" {
|
|
targetHost = rb.address
|
|
} else {
|
|
// Get SSH config values for this host
|
|
targetHost = ssh_config.Get(rb.hostname, "Hostname")
|
|
if targetHost == "" {
|
|
targetHost = rb.hostname
|
|
}
|
|
}
|
|
|
|
portStr := ssh_config.Get(rb.hostname, "Port")
|
|
port := rb.port
|
|
if portStr != "" {
|
|
if p, err := strconv.Atoi(portStr); err == nil {
|
|
port = p
|
|
}
|
|
}
|
|
|
|
username := ssh_config.Get(rb.hostname, "User")
|
|
if rb.username != "" {
|
|
username = rb.username
|
|
}
|
|
|
|
keyFile := ssh_config.Get(rb.hostname, "IdentityFile")
|
|
if rb.keyFile != "" {
|
|
keyFile = rb.keyFile
|
|
}
|
|
|
|
config := &ssh.ClientConfig{
|
|
User: username,
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
Timeout: 30 * time.Second,
|
|
}
|
|
|
|
// Apply SSH config crypto settings with compatibility filtering
|
|
if kexAlgorithms := ssh_config.Get(rb.hostname, "KexAlgorithms"); kexAlgorithms != "" && !strings.HasPrefix(kexAlgorithms, "+") {
|
|
// Only apply if it's an explicit list, not a +append
|
|
algorithms := strings.Split(kexAlgorithms, ",")
|
|
var finalAlgorithms []string
|
|
for _, alg := range algorithms {
|
|
finalAlgorithms = append(finalAlgorithms, strings.TrimSpace(alg))
|
|
}
|
|
config.KeyExchanges = finalAlgorithms
|
|
}
|
|
|
|
// Note: Cipher overrides disabled - Go SSH library defaults work better
|
|
// if ciphers := ssh_config.Get(rb.hostname, "Ciphers"); ciphers != "" {
|
|
// config.Ciphers = ...
|
|
// }
|
|
|
|
if macs := ssh_config.Get(rb.hostname, "MACs"); macs != "" {
|
|
macList := strings.Split(macs, ",")
|
|
for i, mac := range macList {
|
|
macList[i] = strings.TrimSpace(mac)
|
|
}
|
|
config.MACs = macList
|
|
}
|
|
|
|
if hostKeyAlgorithms := ssh_config.Get(rb.hostname, "HostKeyAlgorithms"); hostKeyAlgorithms != "" && !strings.HasPrefix(hostKeyAlgorithms, "+") {
|
|
// Only apply if it's an explicit list, not a +append
|
|
algorithms := strings.Split(hostKeyAlgorithms, ",")
|
|
var finalAlgorithms []string
|
|
for _, alg := range algorithms {
|
|
finalAlgorithms = append(finalAlgorithms, strings.TrimSpace(alg))
|
|
}
|
|
config.HostKeyAlgorithms = finalAlgorithms
|
|
}
|
|
|
|
// Try SSH agent first if available
|
|
if sshAuthSock := os.Getenv("SSH_AUTH_SOCK"); sshAuthSock != "" {
|
|
if conn, err := net.Dial("unix", sshAuthSock); err == nil {
|
|
agentClient := agent.NewClient(conn)
|
|
config.Auth = []ssh.AuthMethod{ssh.PublicKeysCallback(agentClient.Signers)}
|
|
}
|
|
}
|
|
|
|
// If SSH agent didn't work, try key file
|
|
if keyFile != "" {
|
|
// Expand ~ in keyFile path
|
|
if strings.HasPrefix(keyFile, "~/") {
|
|
homeDir, err := os.UserHomeDir()
|
|
if err == nil {
|
|
keyFile = filepath.Join(homeDir, keyFile[2:])
|
|
}
|
|
}
|
|
|
|
key, err := ioutil.ReadFile(keyFile)
|
|
if err == nil {
|
|
signer, err := ssh.ParsePrivateKey(key)
|
|
if err != nil {
|
|
fmt.Errorf("unable to parse private key: %v", err)
|
|
} else {
|
|
config.Auth = append(config.Auth, ssh.PublicKeys(signer))
|
|
}
|
|
}
|
|
}
|
|
|
|
// Fall back to password if available
|
|
if rb.password != "" {
|
|
config.Auth = append(config.Auth, ssh.Password(rb.password))
|
|
}
|
|
|
|
if len(config.Auth) == 0 {
|
|
return fmt.Errorf("no authentication method available")
|
|
}
|
|
|
|
// Format address properly for IPv6
|
|
var address string
|
|
if isIPv6(targetHost) {
|
|
address = fmt.Sprintf("[%s]:%d", targetHost, port)
|
|
} else {
|
|
address = fmt.Sprintf("%s:%d", targetHost, port)
|
|
}
|
|
networkType := getNetworkType(targetHost)
|
|
client, err := ssh.Dial(networkType, address, config)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to %s: %v", targetHost, err)
|
|
}
|
|
|
|
rb.client = client
|
|
fmt.Printf("%s: Successfully connected to %s\n", rb.hostname, targetHost)
|
|
return nil
|
|
}
|
|
|
|
// RunCommand executes a command on the router and returns the output
|
|
func (rb *RouterBackup) RunCommand(command string) (string, error) {
|
|
if rb.client == nil {
|
|
return "", fmt.Errorf("no active connection")
|
|
}
|
|
|
|
session, err := rb.client.NewSession()
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to create session: %v", err)
|
|
}
|
|
defer session.Close()
|
|
|
|
output, err := session.CombinedOutput(command)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to execute command '%s': %v", command, err)
|
|
}
|
|
|
|
return string(output), nil
|
|
}
|
|
|
|
// filterOutput returns output with every line that matches at least one of
// excludePatterns removed.
//
// Each pattern is a regular expression matched against individual lines
// (split on "\n"). Patterns are compiled once up front; a pattern that does
// not compile is reported on stderr and otherwise ignored, so it excludes
// nothing. With no usable patterns the output is returned unchanged.
func filterOutput(output string, excludePatterns []string) string {
	if len(excludePatterns) == 0 {
		return output
	}

	compiled := make([]*regexp.Regexp, 0, len(excludePatterns))
	for _, pattern := range excludePatterns {
		re, err := regexp.Compile(pattern)
		if err != nil {
			fmt.Fprintf(os.Stderr, "warning: ignoring invalid exclude pattern %q: %v\n", pattern, err)
			continue
		}
		compiled = append(compiled, re)
	}
	if len(compiled) == 0 {
		return output
	}

	lines := strings.Split(output, "\n")
	kept := make([]string, 0, len(lines))

nextLine:
	for _, line := range lines {
		for _, re := range compiled {
			if re.MatchString(line) {
				continue nextLine
			}
		}
		kept = append(kept, line)
	}

	return strings.Join(kept, "\n")
}
|
|
|
|
// BackupCommands runs multiple commands and saves outputs to files
|
|
func (rb *RouterBackup) BackupCommands(commands []string, excludePatterns []string, outputDir string) error {
|
|
if err := os.MkdirAll(outputDir, 0755); err != nil {
|
|
return fmt.Errorf("failed to create directory %s: %v", outputDir, err)
|
|
}
|
|
|
|
filename := rb.hostname
|
|
finalPath := filepath.Join(outputDir, filename)
|
|
tempPath := finalPath + ".new"
|
|
|
|
// Create temporary file
|
|
file, err := os.Create(tempPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create temporary file %s: %v", tempPath, err)
|
|
}
|
|
file.Close()
|
|
|
|
successCount := 0
|
|
hasErrors := false
|
|
|
|
for i, command := range commands {
|
|
fmt.Printf("%s: Running command %d/%d: %s\n", rb.hostname, i+1, len(commands), command)
|
|
output, err := rb.RunCommand(command)
|
|
|
|
if err != nil {
|
|
fmt.Printf("%s: Error executing '%s': %v\n", rb.hostname, command, err)
|
|
hasErrors = true
|
|
continue
|
|
}
|
|
|
|
// Append to temporary file
|
|
file, err := os.OpenFile(tempPath, os.O_APPEND|os.O_WRONLY, 0644)
|
|
if err != nil {
|
|
fmt.Printf("%s: Failed to open file for writing: %v\n", rb.hostname, err)
|
|
hasErrors = true
|
|
continue
|
|
}
|
|
|
|
fmt.Fprintf(file, "## COMMAND: %s\n", command)
|
|
filteredOutput := filterOutput(output, excludePatterns)
|
|
file.WriteString(filteredOutput)
|
|
file.Close()
|
|
|
|
successCount++
|
|
}
|
|
|
|
fmt.Printf("%s: Summary: %d/%d commands successful\n", rb.hostname, successCount, len(commands))
|
|
|
|
if hasErrors || successCount == 0 {
|
|
// Remove .new suffix and log error
|
|
if err := os.Remove(tempPath); err != nil {
|
|
fmt.Printf("%s: Failed to remove temporary file %s: %v\n", rb.hostname, tempPath, err)
|
|
}
|
|
return fmt.Errorf("device backup incomplete due to command failures")
|
|
}
|
|
|
|
// All commands succeeded, move file into place atomically
|
|
if err := os.Rename(tempPath, finalPath); err != nil {
|
|
return fmt.Errorf("failed to move temporary file to final location: %v", err)
|
|
}
|
|
|
|
fmt.Printf("%s: Output saved to %s\n", rb.hostname, finalPath)
|
|
return nil
|
|
}
|
|
|
|
// Disconnect closes SSH connection
|
|
func (rb *RouterBackup) Disconnect() {
|
|
if rb.client != nil {
|
|
rb.client.Close()
|
|
fmt.Printf("%s: Disconnected\n", rb.hostname)
|
|
}
|
|
}
|
|
|
|
// findDefaultSSHKey returns the path of the first existing default key in
// the user's ~/.ssh directory, trying id_rsa, id_ed25519 and id_ecdsa in that
// order. It returns "" when the home directory cannot be determined or none
// of the candidates exists.
func findDefaultSSHKey() string {
	home, err := os.UserHomeDir()
	if err != nil {
		return ""
	}

	sshDir := filepath.Join(home, ".ssh")
	for _, name := range [...]string{"id_rsa", "id_ed25519", "id_ecdsa"} {
		candidate := filepath.Join(sshDir, name)
		if _, statErr := os.Stat(candidate); statErr != nil {
			continue
		}
		// Logging of the discovered key is left to the caller, which has
		// the hostname context.
		return candidate
	}

	return ""
}
|