diff --git a/src/config.go b/src/config.go index f0f2623..bacc0ee 100644 --- a/src/config.go +++ b/src/config.go @@ -24,6 +24,7 @@ type Device struct { User string `yaml:"user"` Type string `yaml:"type,omitempty"` Commands []string `yaml:"commands,omitempty"` + Address string `yaml:"address,omitempty"` } func readYAMLFile(path string) (map[string]interface{}, error) { diff --git a/src/config_test.go b/src/config_test.go index ee6849a..b5e8fd6 100644 --- a/src/config_test.go +++ b/src/config_test.go @@ -319,3 +319,44 @@ func TestConfigReadComplexMerge(t *testing.T) { t.Errorf("Expected 1 custom command for lab-switch, got %d", len(labDevice.Commands)) } } + +func TestConfigReadAddress(t *testing.T) { + tempDir := t.TempDir() + + // Create config file with address field + configPath := filepath.Join(tempDir, "address-config.yaml") + configContent := `devices: + router-with-address: + user: testuser + address: 192.168.1.100 + router-without-address: + user: testuser` + + err := os.WriteFile(configPath, []byte(configContent), 0644) + if err != nil { + t.Fatalf("Failed to create test config file: %v", err) + } + + cfg, err := ConfigRead([]string{configPath}) + if err != nil { + t.Fatalf("Failed to load config: %v", err) + } + + // Test device with address + deviceWithAddress, exists := cfg.Devices["router-with-address"] + if !exists { + t.Error("Expected 'router-with-address' to exist") + } + if deviceWithAddress.Address != "192.168.1.100" { + t.Errorf("Expected address '192.168.1.100', got '%s'", deviceWithAddress.Address) + } + + // Test device without address (should be empty) + deviceWithoutAddress, exists := cfg.Devices["router-without-address"] + if !exists { + t.Error("Expected 'router-without-address' to exist") + } + if deviceWithoutAddress.Address != "" { + t.Errorf("Expected empty address, got '%s'", deviceWithoutAddress.Address) + } +} diff --git a/src/main.go b/src/main.go index 982a08f..98eac17 100644 --- a/src/main.go +++ b/src/main.go @@ -100,7 +100,7 @@ func main() { } // Create backup instance - backup := NewRouterBackup(hostname, user, password, keyFile, port) + backup := NewRouterBackup(hostname, deviceConfig.Address, user, password, keyFile, port) // Connect and backup if err := backup.Connect(); err != nil { diff --git a/src/ssh.go b/src/ssh.go index dfc3d67..44b1225 100644 --- a/src/ssh.go +++ b/src/ssh.go @@ -20,6 +20,7 @@ import ( // RouterBackup handles SSH connections and command execution type RouterBackup struct { hostname string + address string username string password string keyFile string @@ -28,9 +29,10 @@ type RouterBackup struct { } // NewRouterBackup creates a new RouterBackup instance -func NewRouterBackup(hostname, username, password, keyFile string, port int) *RouterBackup { +func NewRouterBackup(hostname, address, username, password, keyFile string, port int) *RouterBackup { return &RouterBackup{ hostname: hostname, + address: address, username: username, password: password, keyFile: keyFile, @@ -38,12 +40,32 @@ func NewRouterBackup(hostname, username, password, keyFile string, port int) *Ro } } +// isIPv6 checks if the given address is an IPv6 address +func isIPv6(address string) bool { + ip := net.ParseIP(address) + return ip != nil && ip.To4() == nil +} + +// getNetworkType determines the appropriate network type based on the target address +func getNetworkType(address string) string { + if isIPv6(address) { + return "tcp6" + } + return "tcp4" +} + // Connect establishes SSH connection to the router func (rb *RouterBackup) Connect() error { - // Get SSH config values for this host - hostname := ssh_config.Get(rb.hostname, "Hostname") - if hostname == "" { - hostname = rb.hostname + // Determine the target address - use explicit address if provided, otherwise use hostname + var targetHost string + if rb.address != "" { + targetHost = rb.address + } else { + // Get SSH config values for this host + targetHost = ssh_config.Get(rb.hostname, "Hostname") + if targetHost == "" { + targetHost = rb.hostname + } } portStr := ssh_config.Get(rb.hostname, "Port") @@ -144,14 +166,21 @@ func (rb *RouterBackup) Connect() error { return fmt.Errorf("no authentication method available") } - address := fmt.Sprintf("%s:%d", hostname, port) - client, err := ssh.Dial("tcp4", address, config) + // Format address properly for IPv6 + var address string + if isIPv6(targetHost) { + address = fmt.Sprintf("[%s]:%d", targetHost, port) + } else { + address = fmt.Sprintf("%s:%d", targetHost, port) + } + networkType := getNetworkType(targetHost) + client, err := ssh.Dial(networkType, address, config) if err != nil { - return fmt.Errorf("failed to connect to %s: %v", hostname, err) + return fmt.Errorf("failed to connect to %s: %v", targetHost, err) } rb.client = client - fmt.Printf("Successfully connected to %s\n", hostname) + fmt.Printf("Successfully connected to %s\n", targetHost) return nil } @@ -215,10 +244,10 @@ func (rb *RouterBackup) BackupCommands(commands []string, outputDir string) erro successCount++ } + fmt.Printf("Summary: %d/%d commands successful\n", successCount, len(commands)) if successCount > 0 { fmt.Printf("Output saved to %s\n", filepath) } - fmt.Printf("Summary: %d/%d commands successful\n", successCount, len(commands)) return nil } diff --git a/src/ssh_test.go b/src/ssh_test.go index 01a5c4b..b6c9446 100644 --- a/src/ssh_test.go +++ b/src/ssh_test.go @@ -9,7 +9,7 @@ import ( ) func TestNewRouterBackup(t *testing.T) { - rb := NewRouterBackup("testhost", "testuser", "testpass", "/path/to/key", 2222) + rb := NewRouterBackup("testhost", "", "testuser", "testpass", "/path/to/key", 2222) if rb.hostname != "testhost" { t.Errorf("Expected hostname 'testhost', got '%s'", rb.hostname) @@ -37,7 +37,7 @@ func TestNewRouterBackup(t *testing.T) { } func TestRunCommandWithoutConnection(t *testing.T) { - rb := NewRouterBackup("testhost", "testuser", "testpass", "", 22) + rb := NewRouterBackup("testhost", "", "testuser", "testpass", "", 22) _, err := rb.RunCommand("show version") if err == nil { @@ -53,7 +53,7 @@ func TestBackupCommandsDirectoryCreation(t *testing.T) { tempDir := t.TempDir() outputDir := filepath.Join(tempDir, "nonexistent", "backup") - rb := NewRouterBackup("testhost", "testuser", "testpass", "", 22) + rb := NewRouterBackup("testhost", "", "testuser", "testpass", "", 22) // This should create the directory even without a connection // and fail gracefully when trying to run commands @@ -74,7 +74,7 @@ func TestBackupCommandsDirectoryCreation(t *testing.T) { func TestBackupCommandsEmptyCommands(t *testing.T) { tempDir := t.TempDir() - rb := NewRouterBackup("testhost", "testuser", "testpass", "", 22) + rb := NewRouterBackup("testhost", "", "testuser", "testpass", "", 22) err := rb.BackupCommands([]string{}, tempDir) if err != nil { @@ -89,7 +89,7 @@ func TestBackupCommandsEmptyCommands(t *testing.T) { } func TestDisconnectWithoutConnection(t *testing.T) { - rb := NewRouterBackup("testhost", "testuser", "testpass", "", 22) + rb := NewRouterBackup("testhost", "", "testuser", "testpass", "", 22) // Should not panic or error when disconnecting without connection rb.Disconnect() @@ -155,7 +155,7 @@ func TestFindDefaultSSHKeyHomeError(t *testing.T) { func TestBackupCommandsFileOperations(t *testing.T) { tempDir := t.TempDir() - rb := NewRouterBackup("testhost", "testuser", "testpass", "", 22) + rb := NewRouterBackup("testhost", "", "testuser", "testpass", "", 22) // Create some fake commands (they will fail but we can test file operations) commands := []string{"show version", "show interfaces"} @@ -177,7 +177,7 @@ func TestBackupCommandsFileOperations(t *testing.T) { } func TestRouterBackupConnectionState(t *testing.T) { - rb := NewRouterBackup("testhost", "testuser", "testpass", "", 22) + rb := NewRouterBackup("testhost", "", "testuser", "testpass", "", 22) // Initially no client if rb.client != nil { @@ -190,3 +190,94 @@ func TestRouterBackupConnectionState(t *testing.T) { t.Error("Expected client to remain nil after disconnect") } } + +func TestNewRouterBackupWithAddress(t *testing.T) { + rb := NewRouterBackup("testhost", "192.168.1.100", "testuser", "testpass", "/path/to/key", 2222) + + if rb.hostname != "testhost" { + t.Errorf("Expected hostname 'testhost', got '%s'", rb.hostname) + } + + if rb.address != "192.168.1.100" { + t.Errorf("Expected address '192.168.1.100', got '%s'", rb.address) + } +} + +func TestIsIPv6(t *testing.T) { + // Test IPv4 addresses + if isIPv6("192.168.1.1") { + t.Error("Expected '192.168.1.1' to be detected as IPv4, not IPv6") + } + + if isIPv6("10.0.0.1") { + t.Error("Expected '10.0.0.1' to be detected as IPv4, not IPv6") + } + + // Test IPv6 addresses + if !isIPv6("2001:678:d78:500::") { + t.Error("Expected '2001:678:d78:500::' to be detected as IPv6") + } + + if !isIPv6("::1") { + t.Error("Expected '::1' to be detected as IPv6") + } + + if !isIPv6("fe80::1") { + t.Error("Expected 'fe80::1' to be detected as IPv6") + } + + // Test invalid addresses + if isIPv6("invalid.address") { + t.Error("Expected 'invalid.address' to not be detected as IPv6") + } + + if isIPv6("hostname.example.com") { + t.Error("Expected 'hostname.example.com' to not be detected as IPv6") + } +} + +func TestGetNetworkType(t *testing.T) { + // Test IPv4 addresses + if getNetworkType("192.168.1.1") != "tcp4" { + t.Errorf("Expected 'tcp4' for IPv4 address, got '%s'", getNetworkType("192.168.1.1")) + } + + // Test IPv6 addresses + if getNetworkType("2001:678:d78:500::") != "tcp6" { + t.Errorf("Expected 'tcp6' for IPv6 address, got '%s'", getNetworkType("2001:678:d78:500::")) + } + + if getNetworkType("::1") != "tcp6" { + t.Errorf("Expected 'tcp6' for IPv6 address, got '%s'", getNetworkType("::1")) + } + + // Test hostnames (should default to tcp4) + if getNetworkType("hostname.example.com") != "tcp4" { + t.Errorf("Expected 'tcp4' for hostname, got '%s'", getNetworkType("hostname.example.com")) + } +} + +func TestIPv6AddressFormatting(t *testing.T) { + // Test that we can create a RouterBackup with IPv6 address + // and that it stores the address correctly + rb := NewRouterBackup("testhost", "2001:678:d78:500::", "testuser", "testpass", "", 22) + + if !isIPv6(rb.address) { + t.Error("Expected IPv6 address to be detected as IPv6") + } + + if getNetworkType(rb.address) != "tcp6" { + t.Error("Expected IPv6 address to use tcp6 network type") + } + + // Test IPv4 for comparison + rb4 := NewRouterBackup("testhost", "192.168.1.1", "testuser", "testpass", "", 22) + + if isIPv6(rb4.address) { + t.Error("Expected IPv4 address to not be detected as IPv6") + } + + if getNetworkType(rb4.address) != "tcp4" { + t.Error("Expected IPv4 address to use tcp4 network type") + } +}