add yaml include feature
This commit is contained in:
79
src/main.go
79
src/main.go
@ -9,6 +9,8 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
@ -180,15 +182,15 @@ func (rb *RouterBackup) Disconnect() {
|
||||
}
|
||||
}
|
||||
|
||||
// loadConfig loads the YAML configuration file
|
||||
// loadConfig loads the YAML configuration file with !include support
|
||||
func loadConfig(configPath string) (*Config, error) {
|
||||
data, err := ioutil.ReadFile(configPath)
|
||||
processedYAML, err := processIncludes(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file %s: %v", configPath, err)
|
||||
return nil, fmt.Errorf("failed to process includes: %v", err)
|
||||
}
|
||||
|
||||
var config Config
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
err = yaml.Unmarshal([]byte(processedYAML), &config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse YAML: %v", err)
|
||||
}
|
||||
@ -196,6 +198,75 @@ func loadConfig(configPath string) (*Config, error) {
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// processIncludes processes YAML files with !include directives (one level deep)
|
||||
func processIncludes(filePath string) (string, error) {
|
||||
// Read the file
|
||||
data, err := ioutil.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read file %s: %v", filePath, err)
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
|
||||
// Process !include directives
|
||||
// Match patterns like: !include path/to/file.yaml (excluding commented lines)
|
||||
includeRegex := regexp.MustCompile(`(?m)^(\s*)!include\s+(.+)$`)
|
||||
|
||||
baseDir := filepath.Dir(filePath)
|
||||
|
||||
// Process includes line by line to avoid conflicts
|
||||
lines := strings.Split(content, "\n")
|
||||
var resultLines []string
|
||||
|
||||
for _, line := range lines {
|
||||
// Check if this line matches our include pattern
|
||||
if match := includeRegex.FindStringSubmatch(line); match != nil {
|
||||
leadingWhitespace := match[1]
|
||||
includePath := strings.TrimSpace(match[2])
|
||||
|
||||
// Skip commented lines
|
||||
if strings.Contains(strings.TrimSpace(line), "#") && strings.Index(strings.TrimSpace(line), "#") < strings.Index(strings.TrimSpace(line), "!include") {
|
||||
resultLines = append(resultLines, line)
|
||||
continue
|
||||
}
|
||||
|
||||
// Remove quotes if present
|
||||
includePath = strings.Trim(includePath, "\"'")
|
||||
|
||||
// Make path relative to current config file
|
||||
if !filepath.IsAbs(includePath) {
|
||||
includePath = filepath.Join(baseDir, includePath)
|
||||
}
|
||||
|
||||
// Read the included file
|
||||
includedData, err := ioutil.ReadFile(includePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read include file %s: %v", includePath, err)
|
||||
}
|
||||
|
||||
// Use the captured leading whitespace as indentation prefix
|
||||
indentPrefix := leadingWhitespace
|
||||
|
||||
// Indent each line of included content to match the !include line's indentation
|
||||
includedLines := strings.Split(string(includedData), "\n")
|
||||
for _, includeLine := range includedLines {
|
||||
if strings.TrimSpace(includeLine) == "" {
|
||||
resultLines = append(resultLines, "")
|
||||
} else {
|
||||
resultLines = append(resultLines, indentPrefix+includeLine)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Regular line, just copy it
|
||||
resultLines = append(resultLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
content = strings.Join(resultLines, "\n")
|
||||
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// findDefaultSSHKey looks for default SSH keys
|
||||
func findDefaultSSHKey() string {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
|
201
src/main_test.go
201
src/main_test.go
@ -5,6 +5,7 @@ package main
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -71,15 +72,25 @@ func TestFindDefaultSSHKeyNotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLoadConfig(t *testing.T) {
|
||||
// Create a temporary config file
|
||||
// Create a temporary directory and files
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "test-config.yaml")
|
||||
|
||||
// Create device-types.yaml file
|
||||
deviceTypesPath := filepath.Join(tempDir, "device-types.yaml")
|
||||
deviceTypesContent := `test-type:
|
||||
commands:
|
||||
- show version
|
||||
- show status`
|
||||
|
||||
err := os.WriteFile(deviceTypesPath, []byte(deviceTypesContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create device-types file: %v", err)
|
||||
}
|
||||
|
||||
// Create main config file with !include
|
||||
configPath := filepath.Join(tempDir, "test-config.yaml")
|
||||
configContent := `types:
|
||||
test-type:
|
||||
commands:
|
||||
- show version
|
||||
- show status
|
||||
!include device-types.yaml
|
||||
|
||||
devices:
|
||||
test-device:
|
||||
@ -91,7 +102,7 @@ devices:
|
||||
- direct command
|
||||
`
|
||||
|
||||
err := os.WriteFile(configPath, []byte(configContent), 0644)
|
||||
err = os.WriteFile(configPath, []byte(configContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test config file: %v", err)
|
||||
}
|
||||
@ -285,3 +296,179 @@ func TestRouterBackupCreation(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test !include functionality
|
||||
func TestProcessIncludes(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create included file
|
||||
includedPath := filepath.Join(tempDir, "included.yaml")
|
||||
includedContent := `test-type:
|
||||
commands:
|
||||
- show version
|
||||
- show status`
|
||||
|
||||
err := os.WriteFile(includedPath, []byte(includedContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create included file: %v", err)
|
||||
}
|
||||
|
||||
// Create main file with !include
|
||||
mainPath := filepath.Join(tempDir, "main.yaml")
|
||||
mainContent := `types:
|
||||
!include included.yaml
|
||||
devices:
|
||||
test-device:
|
||||
user: testuser
|
||||
type: test-type`
|
||||
|
||||
err = os.WriteFile(mainPath, []byte(mainContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create main file: %v", err)
|
||||
}
|
||||
|
||||
// Process includes
|
||||
result, err := processIncludes(mainPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to process includes: %v", err)
|
||||
}
|
||||
|
||||
// Check that include was processed
|
||||
if !strings.Contains(result, "show version") {
|
||||
t.Error("Expected included content to be present in result")
|
||||
}
|
||||
if !strings.Contains(result, "show status") {
|
||||
t.Error("Expected included content to be present in result")
|
||||
}
|
||||
if strings.Contains(result, "!include") {
|
||||
t.Error("Expected !include directive to be replaced")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessIncludesWithQuotes(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create included file with spaces in name
|
||||
includedPath := filepath.Join(tempDir, "file with spaces.yaml")
|
||||
includedContent := `production-srlinux:
|
||||
commands:
|
||||
- show version`
|
||||
|
||||
err := os.WriteFile(includedPath, []byte(includedContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create included file: %v", err)
|
||||
}
|
||||
|
||||
// Create main file with quoted !include
|
||||
mainPath := filepath.Join(tempDir, "main.yaml")
|
||||
mainContent := `types:
|
||||
!include "file with spaces.yaml"`
|
||||
|
||||
err = os.WriteFile(mainPath, []byte(mainContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create main file: %v", err)
|
||||
}
|
||||
|
||||
// Process includes
|
||||
result, err := processIncludes(mainPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to process includes: %v", err)
|
||||
}
|
||||
|
||||
// Check that include was processed
|
||||
if !strings.Contains(result, "production-srlinux") {
|
||||
t.Error("Expected included content to be present in result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessIncludesNonexistentFile(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create main file with include to nonexistent file
|
||||
mainPath := filepath.Join(tempDir, "main.yaml")
|
||||
mainContent := `types:
|
||||
!include nonexistent.yaml`
|
||||
|
||||
err := os.WriteFile(mainPath, []byte(mainContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create main file: %v", err)
|
||||
}
|
||||
|
||||
// Process includes should fail
|
||||
_, err = processIncludes(mainPath)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nonexistent include file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigWithIncludes(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create device types file
|
||||
typesPath := filepath.Join(tempDir, "types.yaml")
|
||||
typesContent := `srlinux:
|
||||
commands:
|
||||
- show version
|
||||
- show platform linecard
|
||||
eos:
|
||||
commands:
|
||||
- show version
|
||||
- show inventory`
|
||||
|
||||
err := os.WriteFile(typesPath, []byte(typesContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create types file: %v", err)
|
||||
}
|
||||
|
||||
// Create main config file with includes
|
||||
mainPath := filepath.Join(tempDir, "config.yaml")
|
||||
mainContent := `types:
|
||||
!include types.yaml
|
||||
devices:
|
||||
asw100:
|
||||
user: admin
|
||||
type: srlinux
|
||||
edge-01:
|
||||
user: operator
|
||||
type: eos`
|
||||
|
||||
err = os.WriteFile(mainPath, []byte(mainContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create main config file: %v", err)
|
||||
}
|
||||
|
||||
// Load configuration
|
||||
config, err := loadConfig(mainPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config with includes: %v", err)
|
||||
}
|
||||
|
||||
// Verify types were loaded correctly
|
||||
if len(config.Types) != 2 {
|
||||
t.Errorf("Expected 2 types, got %d", len(config.Types))
|
||||
}
|
||||
|
||||
srlinuxType, exists := config.Types["srlinux"]
|
||||
if !exists {
|
||||
t.Error("Expected 'srlinux' type to exist")
|
||||
}
|
||||
if len(srlinuxType.Commands) != 2 {
|
||||
t.Errorf("Expected 2 commands for srlinux type, got %d", len(srlinuxType.Commands))
|
||||
}
|
||||
|
||||
// Verify devices were loaded correctly
|
||||
if len(config.Devices) != 2 {
|
||||
t.Errorf("Expected 2 devices, got %d", len(config.Devices))
|
||||
}
|
||||
|
||||
asw100, exists := config.Devices["asw100"]
|
||||
if !exists {
|
||||
t.Error("Expected 'asw100' device to exist")
|
||||
}
|
||||
if asw100.User != "admin" {
|
||||
t.Errorf("Expected user 'admin', got '%s'", asw100.User)
|
||||
}
|
||||
if asw100.Type != "srlinux" {
|
||||
t.Errorf("Expected type 'srlinux', got '%s'", asw100.Type)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user