add yaml include feature

This commit is contained in:
2025-07-06 12:27:49 +00:00
parent 9e0469e016
commit 8198b90e60
7 changed files with 453 additions and 111 deletions

View File

@ -9,6 +9,8 @@ import (
"net"
"os"
"path/filepath"
"regexp"
"strings"
"time"
"github.com/spf13/cobra"
@ -180,15 +182,15 @@ func (rb *RouterBackup) Disconnect() {
}
}
// loadConfig loads the YAML configuration file
// loadConfig loads the YAML configuration file with !include support
func loadConfig(configPath string) (*Config, error) {
data, err := ioutil.ReadFile(configPath)
processedYAML, err := processIncludes(configPath)
if err != nil {
return nil, fmt.Errorf("failed to read config file %s: %v", configPath, err)
return nil, fmt.Errorf("failed to process includes: %v", err)
}
var config Config
err = yaml.Unmarshal(data, &config)
err = yaml.Unmarshal([]byte(processedYAML), &config)
if err != nil {
return nil, fmt.Errorf("failed to parse YAML: %v", err)
}
@ -196,6 +198,75 @@ func loadConfig(configPath string) (*Config, error) {
return &config, nil
}
// processIncludes processes YAML files with !include directives (one level deep)
func processIncludes(filePath string) (string, error) {
// Read the file
data, err := ioutil.ReadFile(filePath)
if err != nil {
return "", fmt.Errorf("failed to read file %s: %v", filePath, err)
}
content := string(data)
// Process !include directives
// Match patterns like: !include path/to/file.yaml (excluding commented lines)
includeRegex := regexp.MustCompile(`(?m)^(\s*)!include\s+(.+)$`)
baseDir := filepath.Dir(filePath)
// Process includes line by line to avoid conflicts
lines := strings.Split(content, "\n")
var resultLines []string
for _, line := range lines {
// Check if this line matches our include pattern
if match := includeRegex.FindStringSubmatch(line); match != nil {
leadingWhitespace := match[1]
includePath := strings.TrimSpace(match[2])
// Skip commented lines
if strings.Contains(strings.TrimSpace(line), "#") && strings.Index(strings.TrimSpace(line), "#") < strings.Index(strings.TrimSpace(line), "!include") {
resultLines = append(resultLines, line)
continue
}
// Remove quotes if present
includePath = strings.Trim(includePath, "\"'")
// Make path relative to current config file
if !filepath.IsAbs(includePath) {
includePath = filepath.Join(baseDir, includePath)
}
// Read the included file
includedData, err := ioutil.ReadFile(includePath)
if err != nil {
return "", fmt.Errorf("failed to read include file %s: %v", includePath, err)
}
// Use the captured leading whitespace as indentation prefix
indentPrefix := leadingWhitespace
// Indent each line of included content to match the !include line's indentation
includedLines := strings.Split(string(includedData), "\n")
for _, includeLine := range includedLines {
if strings.TrimSpace(includeLine) == "" {
resultLines = append(resultLines, "")
} else {
resultLines = append(resultLines, indentPrefix+includeLine)
}
}
} else {
// Regular line, just copy it
resultLines = append(resultLines, line)
}
}
content = strings.Join(resultLines, "\n")
return content, nil
}
// findDefaultSSHKey looks for default SSH keys
func findDefaultSSHKey() string {
homeDir, err := os.UserHomeDir()

View File

@ -5,6 +5,7 @@ package main
import (
"os"
"path/filepath"
"strings"
"testing"
)
@ -71,15 +72,25 @@ func TestFindDefaultSSHKeyNotFound(t *testing.T) {
}
func TestLoadConfig(t *testing.T) {
// Create a temporary config file
// Create a temporary directory and files
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "test-config.yaml")
// Create device-types.yaml file
deviceTypesPath := filepath.Join(tempDir, "device-types.yaml")
deviceTypesContent := `test-type:
commands:
- show version
- show status`
err := os.WriteFile(deviceTypesPath, []byte(deviceTypesContent), 0644)
if err != nil {
t.Fatalf("Failed to create device-types file: %v", err)
}
// Create main config file with !include
configPath := filepath.Join(tempDir, "test-config.yaml")
configContent := `types:
test-type:
commands:
- show version
- show status
!include device-types.yaml
devices:
test-device:
@ -91,7 +102,7 @@ devices:
- direct command
`
err := os.WriteFile(configPath, []byte(configContent), 0644)
err = os.WriteFile(configPath, []byte(configContent), 0644)
if err != nil {
t.Fatalf("Failed to create test config file: %v", err)
}
@ -285,3 +296,179 @@ func TestRouterBackupCreation(t *testing.T) {
})
}
}
// Test !include functionality
func TestProcessIncludes(t *testing.T) {
tempDir := t.TempDir()
// Create included file
includedPath := filepath.Join(tempDir, "included.yaml")
includedContent := `test-type:
commands:
- show version
- show status`
err := os.WriteFile(includedPath, []byte(includedContent), 0644)
if err != nil {
t.Fatalf("Failed to create included file: %v", err)
}
// Create main file with !include
mainPath := filepath.Join(tempDir, "main.yaml")
mainContent := `types:
!include included.yaml
devices:
test-device:
user: testuser
type: test-type`
err = os.WriteFile(mainPath, []byte(mainContent), 0644)
if err != nil {
t.Fatalf("Failed to create main file: %v", err)
}
// Process includes
result, err := processIncludes(mainPath)
if err != nil {
t.Fatalf("Failed to process includes: %v", err)
}
// Check that include was processed
if !strings.Contains(result, "show version") {
t.Error("Expected included content to be present in result")
}
if !strings.Contains(result, "show status") {
t.Error("Expected included content to be present in result")
}
if strings.Contains(result, "!include") {
t.Error("Expected !include directive to be replaced")
}
}
func TestProcessIncludesWithQuotes(t *testing.T) {
tempDir := t.TempDir()
// Create included file with spaces in name
includedPath := filepath.Join(tempDir, "file with spaces.yaml")
includedContent := `production-srlinux:
commands:
- show version`
err := os.WriteFile(includedPath, []byte(includedContent), 0644)
if err != nil {
t.Fatalf("Failed to create included file: %v", err)
}
// Create main file with quoted !include
mainPath := filepath.Join(tempDir, "main.yaml")
mainContent := `types:
!include "file with spaces.yaml"`
err = os.WriteFile(mainPath, []byte(mainContent), 0644)
if err != nil {
t.Fatalf("Failed to create main file: %v", err)
}
// Process includes
result, err := processIncludes(mainPath)
if err != nil {
t.Fatalf("Failed to process includes: %v", err)
}
// Check that include was processed
if !strings.Contains(result, "production-srlinux") {
t.Error("Expected included content to be present in result")
}
}
func TestProcessIncludesNonexistentFile(t *testing.T) {
tempDir := t.TempDir()
// Create main file with include to nonexistent file
mainPath := filepath.Join(tempDir, "main.yaml")
mainContent := `types:
!include nonexistent.yaml`
err := os.WriteFile(mainPath, []byte(mainContent), 0644)
if err != nil {
t.Fatalf("Failed to create main file: %v", err)
}
// Process includes should fail
_, err = processIncludes(mainPath)
if err == nil {
t.Error("Expected error for nonexistent include file")
}
}
func TestLoadConfigWithIncludes(t *testing.T) {
tempDir := t.TempDir()
// Create device types file
typesPath := filepath.Join(tempDir, "types.yaml")
typesContent := `srlinux:
commands:
- show version
- show platform linecard
eos:
commands:
- show version
- show inventory`
err := os.WriteFile(typesPath, []byte(typesContent), 0644)
if err != nil {
t.Fatalf("Failed to create types file: %v", err)
}
// Create main config file with includes
mainPath := filepath.Join(tempDir, "config.yaml")
mainContent := `types:
!include types.yaml
devices:
asw100:
user: admin
type: srlinux
edge-01:
user: operator
type: eos`
err = os.WriteFile(mainPath, []byte(mainContent), 0644)
if err != nil {
t.Fatalf("Failed to create main config file: %v", err)
}
// Load configuration
config, err := loadConfig(mainPath)
if err != nil {
t.Fatalf("Failed to load config with includes: %v", err)
}
// Verify types were loaded correctly
if len(config.Types) != 2 {
t.Errorf("Expected 2 types, got %d", len(config.Types))
}
srlinuxType, exists := config.Types["srlinux"]
if !exists {
t.Error("Expected 'srlinux' type to exist")
}
if len(srlinuxType.Commands) != 2 {
t.Errorf("Expected 2 commands for srlinux type, got %d", len(srlinuxType.Commands))
}
// Verify devices were loaded correctly
if len(config.Devices) != 2 {
t.Errorf("Expected 2 devices, got %d", len(config.Devices))
}
asw100, exists := config.Devices["asw100"]
if !exists {
t.Error("Expected 'asw100' device to exist")
}
if asw100.User != "admin" {
t.Errorf("Expected user 'admin', got '%s'", asw100.User)
}
if asw100.Type != "srlinux" {
t.Errorf("Expected type 'srlinux', got '%s'", asw100.Type)
}
}