Move to yaml.v3 and mergo. Refactor config parsing into a package. Refactor SSH connections into a package. Create default YAML directory, and update docs
This commit is contained in:
73
src/config/config.go
Normal file
73
src/config/config.go
Normal file
@ -0,0 +1,73 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
|
||||
"dario.cat/mergo"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config structures
|
||||
type Config struct {
|
||||
Types map[string]DeviceType `yaml:"types"`
|
||||
Devices map[string]Device `yaml:"devices"`
|
||||
}
|
||||
|
||||
type DeviceType struct {
|
||||
Commands []string `yaml:"commands"`
|
||||
}
|
||||
|
||||
type Device struct {
|
||||
User string `yaml:"user"`
|
||||
Type string `yaml:"type,omitempty"`
|
||||
Commands []string `yaml:"commands,omitempty"`
|
||||
}
|
||||
|
||||
func readYAMLFile(path string) (map[string]interface{}, error) {
|
||||
data, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := yaml.Unmarshal(data, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ConfigRead loads and merges multiple YAML files into a single config object
|
||||
func ConfigRead(yamlFiles []string) (*Config, error) {
|
||||
var finalConfig map[string]interface{}
|
||||
|
||||
for _, file := range yamlFiles {
|
||||
current, err := readYAMLFile(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse %s: %v", file, err)
|
||||
}
|
||||
|
||||
if finalConfig == nil {
|
||||
finalConfig = current
|
||||
} else {
|
||||
err := mergo.Merge(&finalConfig, current, mergo.WithOverride)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to merge %s: %v", file, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert back to structured config
|
||||
out, err := yaml.Marshal(finalConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal merged config: %v", err)
|
||||
}
|
||||
|
||||
var config Config
|
||||
if err := yaml.Unmarshal(out, &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal to Config struct: %v", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
@ -3,14 +3,15 @@ module router_backup
|
||||
go 1.21
|
||||
|
||||
require (
|
||||
dario.cat/mergo v1.0.2
|
||||
github.com/kevinburke/ssh_config v1.2.0
|
||||
github.com/spf13/cobra v1.8.0
|
||||
golang.org/x/crypto v0.18.0
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/kevinburke/ssh_config v1.2.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
golang.org/x/sys v0.16.0 // indirect
|
||||
)
|
||||
|
@ -1,3 +1,5 @@
|
||||
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
||||
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
@ -16,6 +18,5 @@ golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE=
|
||||
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
365
src/main.go
365
src/main.go
@ -4,362 +4,25 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"router_backup/config"
|
||||
|
||||
"github.com/kevinburke/ssh_config"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
const Version = "1.0.2"
|
||||
|
||||
// Config structures
|
||||
type Config struct {
|
||||
Types map[string]DeviceType `yaml:"types"`
|
||||
Devices map[string]Device `yaml:"devices"`
|
||||
}
|
||||
// Config and SSH types are now in separate packages
|
||||
|
||||
type DeviceType struct {
|
||||
Commands []string `yaml:"commands"`
|
||||
}
|
||||
// SSH connection methods are now in ssh.go
|
||||
|
||||
type Device struct {
|
||||
User string `yaml:"user"`
|
||||
Type string `yaml:"type,omitempty"`
|
||||
Commands []string `yaml:"commands,omitempty"`
|
||||
}
|
||||
// YAML processing is now handled by the config package
|
||||
|
||||
// RouterBackup handles SSH connections and command execution
|
||||
type RouterBackup struct {
|
||||
hostname string
|
||||
username string
|
||||
password string
|
||||
keyFile string
|
||||
port int
|
||||
client *ssh.Client
|
||||
}
|
||||
|
||||
// NewRouterBackup creates a new RouterBackup instance
|
||||
func NewRouterBackup(hostname, username, password, keyFile string, port int) *RouterBackup {
|
||||
return &RouterBackup{
|
||||
hostname: hostname,
|
||||
username: username,
|
||||
password: password,
|
||||
keyFile: keyFile,
|
||||
port: port,
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes SSH connection to the router
|
||||
func (rb *RouterBackup) Connect() error {
|
||||
// Get SSH config values for this host
|
||||
hostname := ssh_config.Get(rb.hostname, "Hostname")
|
||||
if hostname == "" {
|
||||
hostname = rb.hostname
|
||||
}
|
||||
|
||||
portStr := ssh_config.Get(rb.hostname, "Port")
|
||||
port := rb.port
|
||||
if portStr != "" {
|
||||
if p, err := strconv.Atoi(portStr); err == nil {
|
||||
port = p
|
||||
}
|
||||
}
|
||||
|
||||
username := ssh_config.Get(rb.hostname, "User")
|
||||
if rb.username != "" {
|
||||
username = rb.username
|
||||
}
|
||||
|
||||
keyFile := ssh_config.Get(rb.hostname, "IdentityFile")
|
||||
if rb.keyFile != "" {
|
||||
keyFile = rb.keyFile
|
||||
}
|
||||
|
||||
config := &ssh.ClientConfig{
|
||||
User: username,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Apply SSH config crypto settings with compatibility filtering
|
||||
if kexAlgorithms := ssh_config.Get(rb.hostname, "KexAlgorithms"); kexAlgorithms != "" && !strings.HasPrefix(kexAlgorithms, "+") {
|
||||
// Only apply if it's an explicit list, not a +append
|
||||
algorithms := strings.Split(kexAlgorithms, ",")
|
||||
var finalAlgorithms []string
|
||||
for _, alg := range algorithms {
|
||||
finalAlgorithms = append(finalAlgorithms, strings.TrimSpace(alg))
|
||||
}
|
||||
config.KeyExchanges = finalAlgorithms
|
||||
}
|
||||
|
||||
// Note: Cipher overrides disabled - Go SSH library defaults work
|
||||
// if ciphers := ssh_config.Get(rb.hostname, "Ciphers"); ciphers != "" {
|
||||
// config.Ciphers = ...
|
||||
// }
|
||||
|
||||
if macs := ssh_config.Get(rb.hostname, "MACs"); macs != "" {
|
||||
macList := strings.Split(macs, ",")
|
||||
for i, mac := range macList {
|
||||
macList[i] = strings.TrimSpace(mac)
|
||||
}
|
||||
config.MACs = macList
|
||||
}
|
||||
|
||||
if hostKeyAlgorithms := ssh_config.Get(rb.hostname, "HostKeyAlgorithms"); hostKeyAlgorithms != "" && !strings.HasPrefix(hostKeyAlgorithms, "+") {
|
||||
// Only apply if it's an explicit list, not a +append
|
||||
algorithms := strings.Split(hostKeyAlgorithms, ",")
|
||||
var finalAlgorithms []string
|
||||
for _, alg := range algorithms {
|
||||
finalAlgorithms = append(finalAlgorithms, strings.TrimSpace(alg))
|
||||
}
|
||||
config.HostKeyAlgorithms = finalAlgorithms
|
||||
}
|
||||
|
||||
// Try SSH agent first if available
|
||||
if sshAuthSock := os.Getenv("SSH_AUTH_SOCK"); sshAuthSock != "" {
|
||||
if conn, err := net.Dial("unix", sshAuthSock); err == nil {
|
||||
agentClient := agent.NewClient(conn)
|
||||
config.Auth = []ssh.AuthMethod{ssh.PublicKeysCallback(agentClient.Signers)}
|
||||
}
|
||||
}
|
||||
|
||||
// If SSH agent didn't work, try key file
|
||||
if len(config.Auth) == 0 && keyFile != "" {
|
||||
// Expand ~ in keyFile path
|
||||
if strings.HasPrefix(keyFile, "~/") {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err == nil {
|
||||
keyFile = filepath.Join(homeDir, keyFile[2:])
|
||||
}
|
||||
}
|
||||
|
||||
key, err := ioutil.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read private key: %v", err)
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKey(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse private key: %v", err)
|
||||
}
|
||||
|
||||
config.Auth = []ssh.AuthMethod{ssh.PublicKeys(signer)}
|
||||
}
|
||||
|
||||
// Fall back to password if available
|
||||
if len(config.Auth) == 0 && rb.password != "" {
|
||||
config.Auth = []ssh.AuthMethod{ssh.Password(rb.password)}
|
||||
}
|
||||
|
||||
if len(config.Auth) == 0 {
|
||||
return fmt.Errorf("no authentication method available")
|
||||
}
|
||||
|
||||
address := fmt.Sprintf("%s:%d", hostname, port)
|
||||
client, err := ssh.Dial("tcp4", address, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to %s: %v", hostname, err)
|
||||
}
|
||||
|
||||
rb.client = client
|
||||
fmt.Printf("Successfully connected to %s\n", hostname)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RunCommand executes a command on the router and returns the output
|
||||
func (rb *RouterBackup) RunCommand(command string) (string, error) {
|
||||
if rb.client == nil {
|
||||
return "", fmt.Errorf("no active connection")
|
||||
}
|
||||
|
||||
session, err := rb.client.NewSession()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create session: %v", err)
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
output, err := session.CombinedOutput(command)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute command '%s': %v", command, err)
|
||||
}
|
||||
|
||||
return string(output), nil
|
||||
}
|
||||
|
||||
// BackupCommands runs multiple commands and saves outputs to files
|
||||
func (rb *RouterBackup) BackupCommands(commands []string, outputDir string) error {
|
||||
if err := os.MkdirAll(outputDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create directory %s: %v", outputDir, err)
|
||||
}
|
||||
|
||||
filename := rb.hostname
|
||||
filepath := filepath.Join(outputDir, filename)
|
||||
|
||||
// Truncate file at start
|
||||
file, err := os.Create(filepath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create file %s: %v", filepath, err)
|
||||
}
|
||||
file.Close()
|
||||
|
||||
successCount := 0
|
||||
for i, command := range commands {
|
||||
fmt.Printf("Running command %d/%d: %s\n", i+1, len(commands), command)
|
||||
output, err := rb.RunCommand(command)
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("Error executing '%s': %v\n", command, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Append to file
|
||||
file, err := os.OpenFile(filepath, os.O_APPEND|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to open file for writing: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Fprintf(file, "## COMMAND: %s\n", command)
|
||||
file.WriteString(output)
|
||||
file.Close()
|
||||
|
||||
fmt.Printf("Output saved to %s\n", filepath)
|
||||
successCount++
|
||||
}
|
||||
|
||||
fmt.Printf("Summary: %d/%d commands successful\n", successCount, len(commands))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disconnect closes SSH connection
|
||||
func (rb *RouterBackup) Disconnect() {
|
||||
if rb.client != nil {
|
||||
rb.client.Close()
|
||||
fmt.Printf("Disconnected from %s\n", rb.hostname)
|
||||
}
|
||||
}
|
||||
|
||||
// loadConfig loads the YAML configuration file with !include support
|
||||
func loadConfig(configPath string) (*Config, error) {
|
||||
processedYAML, err := processIncludes(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process includes: %v", err)
|
||||
}
|
||||
|
||||
var config Config
|
||||
err = yaml.Unmarshal([]byte(processedYAML), &config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse YAML: %v", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// processIncludes processes YAML files with !include directives (one level deep)
|
||||
func processIncludes(filePath string) (string, error) {
|
||||
// Read the file
|
||||
data, err := ioutil.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read file %s: %v", filePath, err)
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
|
||||
// Process !include directives
|
||||
// Match patterns like: !include path/to/file.yaml (excluding commented lines)
|
||||
includeRegex := regexp.MustCompile(`(?m)^(\s*)!include\s+(.+)$`)
|
||||
|
||||
baseDir := filepath.Dir(filePath)
|
||||
|
||||
// Process includes line by line to avoid conflicts
|
||||
lines := strings.Split(content, "\n")
|
||||
var resultLines []string
|
||||
|
||||
for _, line := range lines {
|
||||
// Check if this line matches our include pattern
|
||||
if match := includeRegex.FindStringSubmatch(line); match != nil {
|
||||
leadingWhitespace := match[1]
|
||||
includePath := strings.TrimSpace(match[2])
|
||||
|
||||
// Skip commented lines
|
||||
if strings.Contains(strings.TrimSpace(line), "#") && strings.Index(strings.TrimSpace(line), "#") < strings.Index(strings.TrimSpace(line), "!include") {
|
||||
resultLines = append(resultLines, line)
|
||||
continue
|
||||
}
|
||||
|
||||
// Remove quotes if present
|
||||
includePath = strings.Trim(includePath, "\"'")
|
||||
|
||||
// Make path relative to current config file
|
||||
if !filepath.IsAbs(includePath) {
|
||||
includePath = filepath.Join(baseDir, includePath)
|
||||
}
|
||||
|
||||
// Read the included file
|
||||
includedData, err := ioutil.ReadFile(includePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read include file %s: %v", includePath, err)
|
||||
}
|
||||
|
||||
// Use the captured leading whitespace as indentation prefix
|
||||
indentPrefix := leadingWhitespace
|
||||
|
||||
// Indent each line of included content to match the !include line's indentation
|
||||
includedLines := strings.Split(string(includedData), "\n")
|
||||
for _, includeLine := range includedLines {
|
||||
if strings.TrimSpace(includeLine) == "" {
|
||||
resultLines = append(resultLines, "")
|
||||
} else {
|
||||
resultLines = append(resultLines, indentPrefix+includeLine)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Regular line, just copy it
|
||||
resultLines = append(resultLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
content = strings.Join(resultLines, "\n")
|
||||
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// findDefaultSSHKey looks for default SSH keys
|
||||
func findDefaultSSHKey() string {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
defaultKeys := []string{
|
||||
filepath.Join(homeDir, ".ssh", "id_rsa"),
|
||||
filepath.Join(homeDir, ".ssh", "id_ed25519"),
|
||||
filepath.Join(homeDir, ".ssh", "id_ecdsa"),
|
||||
}
|
||||
|
||||
for _, keyPath := range defaultKeys {
|
||||
if _, err := os.Stat(keyPath); err == nil {
|
||||
fmt.Printf("Using SSH key: %s\n", keyPath)
|
||||
return keyPath
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
// SSH helper functions are now in ssh.go
|
||||
|
||||
func main() {
|
||||
var configPath string
|
||||
var yamlFiles []string
|
||||
var password string
|
||||
var keyFile string
|
||||
var port int
|
||||
@ -375,7 +38,7 @@ func main() {
|
||||
fmt.Printf("IPng Networks Router Backup v%s\n", Version)
|
||||
|
||||
// Load configuration
|
||||
config, err := loadConfig(configPath)
|
||||
cfg, err := config.ConfigRead(yamlFiles)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
@ -393,16 +56,16 @@ func main() {
|
||||
}
|
||||
|
||||
// Process devices
|
||||
if len(config.Devices) == 0 {
|
||||
if len(cfg.Devices) == 0 {
|
||||
log.Fatal("No devices found in config file")
|
||||
}
|
||||
|
||||
// Filter devices if --host flags are provided
|
||||
devicesToProcess := config.Devices
|
||||
devicesToProcess := cfg.Devices
|
||||
if len(hostFilter) > 0 {
|
||||
devicesToProcess = make(map[string]Device)
|
||||
devicesToProcess = make(map[string]config.Device)
|
||||
for _, hostname := range hostFilter {
|
||||
if deviceConfig, exists := config.Devices[hostname]; exists {
|
||||
if deviceConfig, exists := cfg.Devices[hostname]; exists {
|
||||
devicesToProcess[hostname] = deviceConfig
|
||||
} else {
|
||||
fmt.Printf("Warning: Host '%s' not found in config file\n", hostname)
|
||||
@ -422,7 +85,7 @@ func main() {
|
||||
|
||||
// If device has a type, get commands from types section
|
||||
if deviceType != "" {
|
||||
if typeConfig, exists := config.Types[deviceType]; exists {
|
||||
if typeConfig, exists := cfg.Types[deviceType]; exists {
|
||||
commands = typeConfig.Commands
|
||||
}
|
||||
}
|
||||
@ -461,14 +124,14 @@ func main() {
|
||||
},
|
||||
}
|
||||
|
||||
rootCmd.Flags().StringVar(&configPath, "config", "", "YAML configuration file path (required)")
|
||||
rootCmd.Flags().StringSliceVar(&yamlFiles, "yaml", []string{}, "YAML configuration file paths (required, can be repeated)")
|
||||
rootCmd.Flags().StringVar(&password, "password", "", "SSH password")
|
||||
rootCmd.Flags().StringVar(&keyFile, "key-file", "", "SSH private key file path")
|
||||
rootCmd.Flags().IntVar(&port, "port", 22, "SSH port")
|
||||
rootCmd.Flags().StringVar(&outputDir, "output-dir", "/tmp", "Output directory for command output files")
|
||||
rootCmd.Flags().StringSliceVar(&hostFilter, "host", []string{}, "Specific host(s) to process (can be repeated, processes all if not specified)")
|
||||
|
||||
rootCmd.MarkFlagRequired("config")
|
||||
rootCmd.MarkFlagRequired("yaml")
|
||||
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
log.Fatal(err)
|
||||
|
250
src/ssh.go
Normal file
250
src/ssh.go
Normal file
@ -0,0 +1,250 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/kevinburke/ssh_config"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
)
|
||||
|
||||
// RouterBackup handles SSH connections and command execution
|
||||
type RouterBackup struct {
|
||||
hostname string
|
||||
username string
|
||||
password string
|
||||
keyFile string
|
||||
port int
|
||||
client *ssh.Client
|
||||
}
|
||||
|
||||
// NewRouterBackup creates a new RouterBackup instance
|
||||
func NewRouterBackup(hostname, username, password, keyFile string, port int) *RouterBackup {
|
||||
return &RouterBackup{
|
||||
hostname: hostname,
|
||||
username: username,
|
||||
password: password,
|
||||
keyFile: keyFile,
|
||||
port: port,
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes SSH connection to the router
|
||||
func (rb *RouterBackup) Connect() error {
|
||||
// Get SSH config values for this host
|
||||
hostname := ssh_config.Get(rb.hostname, "Hostname")
|
||||
if hostname == "" {
|
||||
hostname = rb.hostname
|
||||
}
|
||||
|
||||
portStr := ssh_config.Get(rb.hostname, "Port")
|
||||
port := rb.port
|
||||
if portStr != "" {
|
||||
if p, err := strconv.Atoi(portStr); err == nil {
|
||||
port = p
|
||||
}
|
||||
}
|
||||
|
||||
username := ssh_config.Get(rb.hostname, "User")
|
||||
if rb.username != "" {
|
||||
username = rb.username
|
||||
}
|
||||
|
||||
keyFile := ssh_config.Get(rb.hostname, "IdentityFile")
|
||||
if rb.keyFile != "" {
|
||||
keyFile = rb.keyFile
|
||||
}
|
||||
|
||||
config := &ssh.ClientConfig{
|
||||
User: username,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Apply SSH config crypto settings with compatibility filtering
|
||||
if kexAlgorithms := ssh_config.Get(rb.hostname, "KexAlgorithms"); kexAlgorithms != "" && !strings.HasPrefix(kexAlgorithms, "+") {
|
||||
// Only apply if it's an explicit list, not a +append
|
||||
algorithms := strings.Split(kexAlgorithms, ",")
|
||||
var finalAlgorithms []string
|
||||
for _, alg := range algorithms {
|
||||
finalAlgorithms = append(finalAlgorithms, strings.TrimSpace(alg))
|
||||
}
|
||||
config.KeyExchanges = finalAlgorithms
|
||||
}
|
||||
|
||||
// Note: Cipher overrides disabled - Go SSH library defaults work better
|
||||
// if ciphers := ssh_config.Get(rb.hostname, "Ciphers"); ciphers != "" {
|
||||
// config.Ciphers = ...
|
||||
// }
|
||||
|
||||
if macs := ssh_config.Get(rb.hostname, "MACs"); macs != "" {
|
||||
macList := strings.Split(macs, ",")
|
||||
for i, mac := range macList {
|
||||
macList[i] = strings.TrimSpace(mac)
|
||||
}
|
||||
config.MACs = macList
|
||||
}
|
||||
|
||||
if hostKeyAlgorithms := ssh_config.Get(rb.hostname, "HostKeyAlgorithms"); hostKeyAlgorithms != "" && !strings.HasPrefix(hostKeyAlgorithms, "+") {
|
||||
// Only apply if it's an explicit list, not a +append
|
||||
algorithms := strings.Split(hostKeyAlgorithms, ",")
|
||||
var finalAlgorithms []string
|
||||
for _, alg := range algorithms {
|
||||
finalAlgorithms = append(finalAlgorithms, strings.TrimSpace(alg))
|
||||
}
|
||||
config.HostKeyAlgorithms = finalAlgorithms
|
||||
}
|
||||
|
||||
// Try SSH agent first if available
|
||||
if sshAuthSock := os.Getenv("SSH_AUTH_SOCK"); sshAuthSock != "" {
|
||||
if conn, err := net.Dial("unix", sshAuthSock); err == nil {
|
||||
agentClient := agent.NewClient(conn)
|
||||
config.Auth = []ssh.AuthMethod{ssh.PublicKeysCallback(agentClient.Signers)}
|
||||
}
|
||||
}
|
||||
|
||||
// If SSH agent didn't work, try key file
|
||||
if len(config.Auth) == 0 && keyFile != "" {
|
||||
// Expand ~ in keyFile path
|
||||
if strings.HasPrefix(keyFile, "~/") {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err == nil {
|
||||
keyFile = filepath.Join(homeDir, keyFile[2:])
|
||||
}
|
||||
}
|
||||
|
||||
key, err := ioutil.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read private key: %v", err)
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKey(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse private key: %v", err)
|
||||
}
|
||||
|
||||
config.Auth = []ssh.AuthMethod{ssh.PublicKeys(signer)}
|
||||
}
|
||||
|
||||
// Fall back to password if available
|
||||
if len(config.Auth) == 0 && rb.password != "" {
|
||||
config.Auth = []ssh.AuthMethod{ssh.Password(rb.password)}
|
||||
}
|
||||
|
||||
if len(config.Auth) == 0 {
|
||||
return fmt.Errorf("no authentication method available")
|
||||
}
|
||||
|
||||
address := fmt.Sprintf("%s:%d", hostname, port)
|
||||
client, err := ssh.Dial("tcp4", address, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to %s: %v", hostname, err)
|
||||
}
|
||||
|
||||
rb.client = client
|
||||
fmt.Printf("Successfully connected to %s\n", hostname)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RunCommand executes a command on the router and returns the output
|
||||
func (rb *RouterBackup) RunCommand(command string) (string, error) {
|
||||
if rb.client == nil {
|
||||
return "", fmt.Errorf("no active connection")
|
||||
}
|
||||
|
||||
session, err := rb.client.NewSession()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create session: %v", err)
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
output, err := session.CombinedOutput(command)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute command '%s': %v", command, err)
|
||||
}
|
||||
|
||||
return string(output), nil
|
||||
}
|
||||
|
||||
// BackupCommands runs multiple commands and saves outputs to files
|
||||
func (rb *RouterBackup) BackupCommands(commands []string, outputDir string) error {
|
||||
if err := os.MkdirAll(outputDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create directory %s: %v", outputDir, err)
|
||||
}
|
||||
|
||||
filename := rb.hostname
|
||||
filepath := filepath.Join(outputDir, filename)
|
||||
|
||||
// Truncate file at start
|
||||
file, err := os.Create(filepath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create file %s: %v", filepath, err)
|
||||
}
|
||||
file.Close()
|
||||
|
||||
successCount := 0
|
||||
for i, command := range commands {
|
||||
fmt.Printf("Running command %d/%d: %s\n", i+1, len(commands), command)
|
||||
output, err := rb.RunCommand(command)
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("Error executing '%s': %v\n", command, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Append to file
|
||||
file, err := os.OpenFile(filepath, os.O_APPEND|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to open file for writing: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Fprintf(file, "## COMMAND: %s\n", command)
|
||||
file.WriteString(output)
|
||||
file.Close()
|
||||
|
||||
fmt.Printf("Output saved to %s\n", filepath)
|
||||
successCount++
|
||||
}
|
||||
|
||||
fmt.Printf("Summary: %d/%d commands successful\n", successCount, len(commands))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disconnect closes SSH connection
|
||||
func (rb *RouterBackup) Disconnect() {
|
||||
if rb.client != nil {
|
||||
rb.client.Close()
|
||||
fmt.Printf("Disconnected from %s\n", rb.hostname)
|
||||
}
|
||||
}
|
||||
|
||||
// findDefaultSSHKey looks for default SSH keys
|
||||
func findDefaultSSHKey() string {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
defaultKeys := []string{
|
||||
filepath.Join(homeDir, ".ssh", "id_rsa"),
|
||||
filepath.Join(homeDir, ".ssh", "id_ed25519"),
|
||||
filepath.Join(homeDir, ".ssh", "id_ecdsa"),
|
||||
}
|
||||
|
||||
for _, keyPath := range defaultKeys {
|
||||
if _, err := os.Stat(keyPath); err == nil {
|
||||
fmt.Printf("Using SSH key: %s\n", keyPath)
|
||||
return keyPath
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
Reference in New Issue
Block a user