Move to yaml.v3 and mergo. Refactor config parsing into a package. Refactor SSH connections into a package. Create default YAML directory, and update docs

This commit is contained in:
Pim van Pelt
2025-07-06 17:11:22 +02:00
parent 75646856aa
commit 769d9eb6cd
11 changed files with 441 additions and 490 deletions

73
src/config/config.go Normal file
View File

@ -0,0 +1,73 @@
package config
import (
"fmt"
"io/ioutil"
"dario.cat/mergo"
"gopkg.in/yaml.v3"
)
// Config structures
type Config struct {
Types map[string]DeviceType `yaml:"types"`
Devices map[string]Device `yaml:"devices"`
}
type DeviceType struct {
Commands []string `yaml:"commands"`
}
type Device struct {
User string `yaml:"user"`
Type string `yaml:"type,omitempty"`
Commands []string `yaml:"commands,omitempty"`
}
func readYAMLFile(path string) (map[string]interface{}, error) {
data, err := ioutil.ReadFile(path)
if err != nil {
return nil, err
}
var result map[string]interface{}
if err := yaml.Unmarshal(data, &result); err != nil {
return nil, err
}
return result, nil
}
// ConfigRead loads and merges multiple YAML files into a single config object
func ConfigRead(yamlFiles []string) (*Config, error) {
var finalConfig map[string]interface{}
for _, file := range yamlFiles {
current, err := readYAMLFile(file)
if err != nil {
return nil, fmt.Errorf("failed to parse %s: %v", file, err)
}
if finalConfig == nil {
finalConfig = current
} else {
err := mergo.Merge(&finalConfig, current, mergo.WithOverride)
if err != nil {
return nil, fmt.Errorf("failed to merge %s: %v", file, err)
}
}
}
// Convert back to structured config
out, err := yaml.Marshal(finalConfig)
if err != nil {
return nil, fmt.Errorf("failed to marshal merged config: %v", err)
}
var config Config
if err := yaml.Unmarshal(out, &config); err != nil {
return nil, fmt.Errorf("failed to unmarshal to Config struct: %v", err)
}
return &config, nil
}

View File

@ -3,14 +3,15 @@ module router_backup
go 1.21
require (
dario.cat/mergo v1.0.2
github.com/kevinburke/ssh_config v1.2.0
github.com/spf13/cobra v1.8.0
golang.org/x/crypto v0.18.0
gopkg.in/yaml.v2 v2.4.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/kevinburke/ssh_config v1.2.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/sys v0.16.0 // indirect
)

View File

@ -1,3 +1,5 @@
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
@ -16,6 +18,5 @@ golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -4,362 +4,25 @@ package main
import (
"fmt"
"io/ioutil"
"log"
"net"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
"router_backup/config"
"github.com/kevinburke/ssh_config"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"gopkg.in/yaml.v2"
)
const Version = "1.0.2"
// Config structures
type Config struct {
Types map[string]DeviceType `yaml:"types"`
Devices map[string]Device `yaml:"devices"`
}
// Config and SSH types are now in separate packages
type DeviceType struct {
Commands []string `yaml:"commands"`
}
// SSH connection methods are now in ssh.go
type Device struct {
User string `yaml:"user"`
Type string `yaml:"type,omitempty"`
Commands []string `yaml:"commands,omitempty"`
}
// YAML processing is now handled by the config package
// RouterBackup handles SSH connections and command execution
type RouterBackup struct {
hostname string
username string
password string
keyFile string
port int
client *ssh.Client
}
// NewRouterBackup creates a new RouterBackup instance
func NewRouterBackup(hostname, username, password, keyFile string, port int) *RouterBackup {
return &RouterBackup{
hostname: hostname,
username: username,
password: password,
keyFile: keyFile,
port: port,
}
}
// Connect establishes SSH connection to the router
func (rb *RouterBackup) Connect() error {
// Get SSH config values for this host
hostname := ssh_config.Get(rb.hostname, "Hostname")
if hostname == "" {
hostname = rb.hostname
}
portStr := ssh_config.Get(rb.hostname, "Port")
port := rb.port
if portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil {
port = p
}
}
username := ssh_config.Get(rb.hostname, "User")
if rb.username != "" {
username = rb.username
}
keyFile := ssh_config.Get(rb.hostname, "IdentityFile")
if rb.keyFile != "" {
keyFile = rb.keyFile
}
config := &ssh.ClientConfig{
User: username,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 30 * time.Second,
}
// Apply SSH config crypto settings with compatibility filtering
if kexAlgorithms := ssh_config.Get(rb.hostname, "KexAlgorithms"); kexAlgorithms != "" && !strings.HasPrefix(kexAlgorithms, "+") {
// Only apply if it's an explicit list, not a +append
algorithms := strings.Split(kexAlgorithms, ",")
var finalAlgorithms []string
for _, alg := range algorithms {
finalAlgorithms = append(finalAlgorithms, strings.TrimSpace(alg))
}
config.KeyExchanges = finalAlgorithms
}
// Note: Cipher overrides disabled - Go SSH library defaults work
// if ciphers := ssh_config.Get(rb.hostname, "Ciphers"); ciphers != "" {
// config.Ciphers = ...
// }
if macs := ssh_config.Get(rb.hostname, "MACs"); macs != "" {
macList := strings.Split(macs, ",")
for i, mac := range macList {
macList[i] = strings.TrimSpace(mac)
}
config.MACs = macList
}
if hostKeyAlgorithms := ssh_config.Get(rb.hostname, "HostKeyAlgorithms"); hostKeyAlgorithms != "" && !strings.HasPrefix(hostKeyAlgorithms, "+") {
// Only apply if it's an explicit list, not a +append
algorithms := strings.Split(hostKeyAlgorithms, ",")
var finalAlgorithms []string
for _, alg := range algorithms {
finalAlgorithms = append(finalAlgorithms, strings.TrimSpace(alg))
}
config.HostKeyAlgorithms = finalAlgorithms
}
// Try SSH agent first if available
if sshAuthSock := os.Getenv("SSH_AUTH_SOCK"); sshAuthSock != "" {
if conn, err := net.Dial("unix", sshAuthSock); err == nil {
agentClient := agent.NewClient(conn)
config.Auth = []ssh.AuthMethod{ssh.PublicKeysCallback(agentClient.Signers)}
}
}
// If SSH agent didn't work, try key file
if len(config.Auth) == 0 && keyFile != "" {
// Expand ~ in keyFile path
if strings.HasPrefix(keyFile, "~/") {
homeDir, err := os.UserHomeDir()
if err == nil {
keyFile = filepath.Join(homeDir, keyFile[2:])
}
}
key, err := ioutil.ReadFile(keyFile)
if err != nil {
return fmt.Errorf("unable to read private key: %v", err)
}
signer, err := ssh.ParsePrivateKey(key)
if err != nil {
return fmt.Errorf("unable to parse private key: %v", err)
}
config.Auth = []ssh.AuthMethod{ssh.PublicKeys(signer)}
}
// Fall back to password if available
if len(config.Auth) == 0 && rb.password != "" {
config.Auth = []ssh.AuthMethod{ssh.Password(rb.password)}
}
if len(config.Auth) == 0 {
return fmt.Errorf("no authentication method available")
}
address := fmt.Sprintf("%s:%d", hostname, port)
client, err := ssh.Dial("tcp4", address, config)
if err != nil {
return fmt.Errorf("failed to connect to %s: %v", hostname, err)
}
rb.client = client
fmt.Printf("Successfully connected to %s\n", hostname)
return nil
}
// RunCommand executes a command on the router and returns the output
func (rb *RouterBackup) RunCommand(command string) (string, error) {
if rb.client == nil {
return "", fmt.Errorf("no active connection")
}
session, err := rb.client.NewSession()
if err != nil {
return "", fmt.Errorf("failed to create session: %v", err)
}
defer session.Close()
output, err := session.CombinedOutput(command)
if err != nil {
return "", fmt.Errorf("failed to execute command '%s': %v", command, err)
}
return string(output), nil
}
// BackupCommands runs multiple commands and saves outputs to files
func (rb *RouterBackup) BackupCommands(commands []string, outputDir string) error {
if err := os.MkdirAll(outputDir, 0755); err != nil {
return fmt.Errorf("failed to create directory %s: %v", outputDir, err)
}
filename := rb.hostname
filepath := filepath.Join(outputDir, filename)
// Truncate file at start
file, err := os.Create(filepath)
if err != nil {
return fmt.Errorf("failed to create file %s: %v", filepath, err)
}
file.Close()
successCount := 0
for i, command := range commands {
fmt.Printf("Running command %d/%d: %s\n", i+1, len(commands), command)
output, err := rb.RunCommand(command)
if err != nil {
fmt.Printf("Error executing '%s': %v\n", command, err)
continue
}
// Append to file
file, err := os.OpenFile(filepath, os.O_APPEND|os.O_WRONLY, 0644)
if err != nil {
fmt.Printf("Failed to open file for writing: %v\n", err)
continue
}
fmt.Fprintf(file, "## COMMAND: %s\n", command)
file.WriteString(output)
file.Close()
fmt.Printf("Output saved to %s\n", filepath)
successCount++
}
fmt.Printf("Summary: %d/%d commands successful\n", successCount, len(commands))
return nil
}
// Disconnect closes SSH connection
func (rb *RouterBackup) Disconnect() {
if rb.client != nil {
rb.client.Close()
fmt.Printf("Disconnected from %s\n", rb.hostname)
}
}
// loadConfig loads the YAML configuration file with !include support
func loadConfig(configPath string) (*Config, error) {
processedYAML, err := processIncludes(configPath)
if err != nil {
return nil, fmt.Errorf("failed to process includes: %v", err)
}
var config Config
err = yaml.Unmarshal([]byte(processedYAML), &config)
if err != nil {
return nil, fmt.Errorf("failed to parse YAML: %v", err)
}
return &config, nil
}
// processIncludes processes YAML files with !include directives (one level deep)
func processIncludes(filePath string) (string, error) {
// Read the file
data, err := ioutil.ReadFile(filePath)
if err != nil {
return "", fmt.Errorf("failed to read file %s: %v", filePath, err)
}
content := string(data)
// Process !include directives
// Match patterns like: !include path/to/file.yaml (excluding commented lines)
includeRegex := regexp.MustCompile(`(?m)^(\s*)!include\s+(.+)$`)
baseDir := filepath.Dir(filePath)
// Process includes line by line to avoid conflicts
lines := strings.Split(content, "\n")
var resultLines []string
for _, line := range lines {
// Check if this line matches our include pattern
if match := includeRegex.FindStringSubmatch(line); match != nil {
leadingWhitespace := match[1]
includePath := strings.TrimSpace(match[2])
// Skip commented lines
if strings.Contains(strings.TrimSpace(line), "#") && strings.Index(strings.TrimSpace(line), "#") < strings.Index(strings.TrimSpace(line), "!include") {
resultLines = append(resultLines, line)
continue
}
// Remove quotes if present
includePath = strings.Trim(includePath, "\"'")
// Make path relative to current config file
if !filepath.IsAbs(includePath) {
includePath = filepath.Join(baseDir, includePath)
}
// Read the included file
includedData, err := ioutil.ReadFile(includePath)
if err != nil {
return "", fmt.Errorf("failed to read include file %s: %v", includePath, err)
}
// Use the captured leading whitespace as indentation prefix
indentPrefix := leadingWhitespace
// Indent each line of included content to match the !include line's indentation
includedLines := strings.Split(string(includedData), "\n")
for _, includeLine := range includedLines {
if strings.TrimSpace(includeLine) == "" {
resultLines = append(resultLines, "")
} else {
resultLines = append(resultLines, indentPrefix+includeLine)
}
}
} else {
// Regular line, just copy it
resultLines = append(resultLines, line)
}
}
content = strings.Join(resultLines, "\n")
return content, nil
}
// findDefaultSSHKey looks for default SSH keys
func findDefaultSSHKey() string {
homeDir, err := os.UserHomeDir()
if err != nil {
return ""
}
defaultKeys := []string{
filepath.Join(homeDir, ".ssh", "id_rsa"),
filepath.Join(homeDir, ".ssh", "id_ed25519"),
filepath.Join(homeDir, ".ssh", "id_ecdsa"),
}
for _, keyPath := range defaultKeys {
if _, err := os.Stat(keyPath); err == nil {
fmt.Printf("Using SSH key: %s\n", keyPath)
return keyPath
}
}
return ""
}
// SSH helper functions are now in ssh.go
func main() {
var configPath string
var yamlFiles []string
var password string
var keyFile string
var port int
@ -375,7 +38,7 @@ func main() {
fmt.Printf("IPng Networks Router Backup v%s\n", Version)
// Load configuration
config, err := loadConfig(configPath)
cfg, err := config.ConfigRead(yamlFiles)
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
@ -393,16 +56,16 @@ func main() {
}
// Process devices
if len(config.Devices) == 0 {
if len(cfg.Devices) == 0 {
log.Fatal("No devices found in config file")
}
// Filter devices if --host flags are provided
devicesToProcess := config.Devices
devicesToProcess := cfg.Devices
if len(hostFilter) > 0 {
devicesToProcess = make(map[string]Device)
devicesToProcess = make(map[string]config.Device)
for _, hostname := range hostFilter {
if deviceConfig, exists := config.Devices[hostname]; exists {
if deviceConfig, exists := cfg.Devices[hostname]; exists {
devicesToProcess[hostname] = deviceConfig
} else {
fmt.Printf("Warning: Host '%s' not found in config file\n", hostname)
@ -422,7 +85,7 @@ func main() {
// If device has a type, get commands from types section
if deviceType != "" {
if typeConfig, exists := config.Types[deviceType]; exists {
if typeConfig, exists := cfg.Types[deviceType]; exists {
commands = typeConfig.Commands
}
}
@ -461,14 +124,14 @@ func main() {
},
}
rootCmd.Flags().StringVar(&configPath, "config", "", "YAML configuration file path (required)")
rootCmd.Flags().StringSliceVar(&yamlFiles, "yaml", []string{}, "YAML configuration file paths (required, can be repeated)")
rootCmd.Flags().StringVar(&password, "password", "", "SSH password")
rootCmd.Flags().StringVar(&keyFile, "key-file", "", "SSH private key file path")
rootCmd.Flags().IntVar(&port, "port", 22, "SSH port")
rootCmd.Flags().StringVar(&outputDir, "output-dir", "/tmp", "Output directory for command output files")
rootCmd.Flags().StringSliceVar(&hostFilter, "host", []string{}, "Specific host(s) to process (can be repeated, processes all if not specified)")
rootCmd.MarkFlagRequired("config")
rootCmd.MarkFlagRequired("yaml")
if err := rootCmd.Execute(); err != nil {
log.Fatal(err)

250
src/ssh.go Normal file
View File

@ -0,0 +1,250 @@
package main
import (
"fmt"
"io/ioutil"
"net"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/kevinburke/ssh_config"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
// RouterBackup handles SSH connections and command execution
type RouterBackup struct {
hostname string
username string
password string
keyFile string
port int
client *ssh.Client
}
// NewRouterBackup creates a new RouterBackup instance
func NewRouterBackup(hostname, username, password, keyFile string, port int) *RouterBackup {
return &RouterBackup{
hostname: hostname,
username: username,
password: password,
keyFile: keyFile,
port: port,
}
}
// Connect establishes SSH connection to the router
func (rb *RouterBackup) Connect() error {
// Get SSH config values for this host
hostname := ssh_config.Get(rb.hostname, "Hostname")
if hostname == "" {
hostname = rb.hostname
}
portStr := ssh_config.Get(rb.hostname, "Port")
port := rb.port
if portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil {
port = p
}
}
username := ssh_config.Get(rb.hostname, "User")
if rb.username != "" {
username = rb.username
}
keyFile := ssh_config.Get(rb.hostname, "IdentityFile")
if rb.keyFile != "" {
keyFile = rb.keyFile
}
config := &ssh.ClientConfig{
User: username,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 30 * time.Second,
}
// Apply SSH config crypto settings with compatibility filtering
if kexAlgorithms := ssh_config.Get(rb.hostname, "KexAlgorithms"); kexAlgorithms != "" && !strings.HasPrefix(kexAlgorithms, "+") {
// Only apply if it's an explicit list, not a +append
algorithms := strings.Split(kexAlgorithms, ",")
var finalAlgorithms []string
for _, alg := range algorithms {
finalAlgorithms = append(finalAlgorithms, strings.TrimSpace(alg))
}
config.KeyExchanges = finalAlgorithms
}
// Note: Cipher overrides disabled - Go SSH library defaults work better
// if ciphers := ssh_config.Get(rb.hostname, "Ciphers"); ciphers != "" {
// config.Ciphers = ...
// }
if macs := ssh_config.Get(rb.hostname, "MACs"); macs != "" {
macList := strings.Split(macs, ",")
for i, mac := range macList {
macList[i] = strings.TrimSpace(mac)
}
config.MACs = macList
}
if hostKeyAlgorithms := ssh_config.Get(rb.hostname, "HostKeyAlgorithms"); hostKeyAlgorithms != "" && !strings.HasPrefix(hostKeyAlgorithms, "+") {
// Only apply if it's an explicit list, not a +append
algorithms := strings.Split(hostKeyAlgorithms, ",")
var finalAlgorithms []string
for _, alg := range algorithms {
finalAlgorithms = append(finalAlgorithms, strings.TrimSpace(alg))
}
config.HostKeyAlgorithms = finalAlgorithms
}
// Try SSH agent first if available
if sshAuthSock := os.Getenv("SSH_AUTH_SOCK"); sshAuthSock != "" {
if conn, err := net.Dial("unix", sshAuthSock); err == nil {
agentClient := agent.NewClient(conn)
config.Auth = []ssh.AuthMethod{ssh.PublicKeysCallback(agentClient.Signers)}
}
}
// If SSH agent didn't work, try key file
if len(config.Auth) == 0 && keyFile != "" {
// Expand ~ in keyFile path
if strings.HasPrefix(keyFile, "~/") {
homeDir, err := os.UserHomeDir()
if err == nil {
keyFile = filepath.Join(homeDir, keyFile[2:])
}
}
key, err := ioutil.ReadFile(keyFile)
if err != nil {
return fmt.Errorf("unable to read private key: %v", err)
}
signer, err := ssh.ParsePrivateKey(key)
if err != nil {
return fmt.Errorf("unable to parse private key: %v", err)
}
config.Auth = []ssh.AuthMethod{ssh.PublicKeys(signer)}
}
// Fall back to password if available
if len(config.Auth) == 0 && rb.password != "" {
config.Auth = []ssh.AuthMethod{ssh.Password(rb.password)}
}
if len(config.Auth) == 0 {
return fmt.Errorf("no authentication method available")
}
address := fmt.Sprintf("%s:%d", hostname, port)
client, err := ssh.Dial("tcp4", address, config)
if err != nil {
return fmt.Errorf("failed to connect to %s: %v", hostname, err)
}
rb.client = client
fmt.Printf("Successfully connected to %s\n", hostname)
return nil
}
// RunCommand executes a command on the router and returns the output
func (rb *RouterBackup) RunCommand(command string) (string, error) {
if rb.client == nil {
return "", fmt.Errorf("no active connection")
}
session, err := rb.client.NewSession()
if err != nil {
return "", fmt.Errorf("failed to create session: %v", err)
}
defer session.Close()
output, err := session.CombinedOutput(command)
if err != nil {
return "", fmt.Errorf("failed to execute command '%s': %v", command, err)
}
return string(output), nil
}
// BackupCommands runs multiple commands and saves outputs to files
func (rb *RouterBackup) BackupCommands(commands []string, outputDir string) error {
if err := os.MkdirAll(outputDir, 0755); err != nil {
return fmt.Errorf("failed to create directory %s: %v", outputDir, err)
}
filename := rb.hostname
filepath := filepath.Join(outputDir, filename)
// Truncate file at start
file, err := os.Create(filepath)
if err != nil {
return fmt.Errorf("failed to create file %s: %v", filepath, err)
}
file.Close()
successCount := 0
for i, command := range commands {
fmt.Printf("Running command %d/%d: %s\n", i+1, len(commands), command)
output, err := rb.RunCommand(command)
if err != nil {
fmt.Printf("Error executing '%s': %v\n", command, err)
continue
}
// Append to file
file, err := os.OpenFile(filepath, os.O_APPEND|os.O_WRONLY, 0644)
if err != nil {
fmt.Printf("Failed to open file for writing: %v\n", err)
continue
}
fmt.Fprintf(file, "## COMMAND: %s\n", command)
file.WriteString(output)
file.Close()
fmt.Printf("Output saved to %s\n", filepath)
successCount++
}
fmt.Printf("Summary: %d/%d commands successful\n", successCount, len(commands))
return nil
}
// Disconnect closes SSH connection
func (rb *RouterBackup) Disconnect() {
if rb.client != nil {
rb.client.Close()
fmt.Printf("Disconnected from %s\n", rb.hostname)
}
}
// findDefaultSSHKey looks for default SSH keys
func findDefaultSSHKey() string {
homeDir, err := os.UserHomeDir()
if err != nil {
return ""
}
defaultKeys := []string{
filepath.Join(homeDir, ".ssh", "id_rsa"),
filepath.Join(homeDir, ".ssh", "id_ed25519"),
filepath.Join(homeDir, ".ssh", "id_ecdsa"),
}
for _, keyPath := range defaultKeys {
if _, err := os.Stat(keyPath); err == nil {
fmt.Printf("Using SSH key: %s\n", keyPath)
return keyPath
}
}
return ""
}