Files
clab-webserver/main.go

302 lines
7.7 KiB
Go

package main
import (
"bytes"
"flag"
"fmt"
"log"
"net"
"net/http"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"text/template"
"text/template/parse"
"time"
"gopkg.in/yaml.v3"
)
// prefixEntry is one parsed client-map record: a CIDR network together with
// the template variables assigned to clients whose address falls inside it.
type prefixEntry struct {
	network *net.IPNet        // network parsed from the CIDR key via net.ParseCIDR
	ones    int               // prefix length of network (from Mask.Size); larger is more specific
	vars    map[string]string // template variables contributed by this prefix
}
// loadClientMap reads the YAML client map at path and returns one prefixEntry
// per CIDR key. The file must be a mapping from CIDR strings to
// string-to-string variable maps, for example:
//
//	10.0.0.0/8:
//	  site: lab
//	10.1.0.0/16:
//	  site: dc1
//
// Entries are returned in sorted CIDR-key order, so the result (and any
// tie-break applied to it downstream) does not depend on Go's randomized map
// iteration and is the same on every reload of the same file.
//
// net.ParseCIDR normalizes a key with host bits set ("10.1.2.3/8") to its
// network. If two keys normalize to the same network, a warning is logged
// because both entries will match exactly the same clients.
//
// It returns an error if the file cannot be read, is not valid YAML of the
// expected shape, or contains a key that is not a valid CIDR.
func loadClientMap(path string) ([]prefixEntry, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	raw := map[string]map[string]string{}
	if err := yaml.Unmarshal(data, &raw); err != nil {
		return nil, fmt.Errorf("parsing %s: %w", path, err)
	}

	// Fix the processing order; ranging over raw directly would shuffle it.
	cidrs := make([]string, 0, len(raw))
	for cidr := range raw {
		cidrs = append(cidrs, cidr)
	}
	sort.Strings(cidrs)

	entries := make([]prefixEntry, 0, len(cidrs))
	byNetwork := make(map[string]string, len(cidrs)) // canonical network -> first key seen
	for _, cidr := range cidrs {
		_, network, err := net.ParseCIDR(cidr)
		if err != nil {
			return nil, fmt.Errorf("invalid CIDR %q: %w", cidr, err)
		}
		canonical := network.String()
		if prev, dup := byNetwork[canonical]; dup {
			log.Printf("WARNING: client-map keys %q and %q both denote network %s", prev, cidr, canonical)
		} else {
			byNetwork[canonical] = cidr
		}
		ones, _ := network.Mask.Size()
		entries = append(entries, prefixEntry{network: network, ones: ones, vars: raw[cidr]})
	}
	return entries, nil
}
// clientMap guards the active client-map entries with a read/write lock so the
// watcher goroutine can swap them while request handlers read them.
// Within this file the slice is only ever replaced wholesale via set (never
// mutated in place), so a slice obtained from get remains usable after a
// later set.
type clientMap struct {
	mu      sync.RWMutex
	entries []prefixEntry
}

// get returns the entries most recently stored with set.
func (cm *clientMap) get() []prefixEntry {
	cm.mu.RLock()
	current := cm.entries
	cm.mu.RUnlock()
	return current
}

// set replaces the stored entries with the given slice.
func (cm *clientMap) set(entries []prefixEntry) {
	cm.mu.Lock()
	cm.entries = entries
	cm.mu.Unlock()
}
// watchClientMap polls path every 5 seconds and reloads the client map into cm
// whenever the file's modification time differs from the one last loaded
// successfully (initialMod to begin with). Any difference counts, not only a
// newer timestamp: tools such as `cp -p`, `rsync -t` or archive extraction can
// install a replacement whose mtime is older than the file it replaces.
//
// On a stat or parse/CIDR error it logs a warning and keeps serving the
// previous data. A failed reload is retried on every poll, because the file may
// have been caught mid-write without its mtime changing afterwards. The reload
// warning is logged only once per failing mtime, so a broken file does not
// flood the log.
//
// It never returns; run it in its own goroutine.
func watchClientMap(path string, cm *clientMap, initialMod time.Time) {
	lastMod := initialMod
	var failedMod time.Time // mtime whose reload failure has already been logged
	for {
		time.Sleep(5 * time.Second)
		info, err := os.Stat(path)
		if err != nil {
			log.Printf("WARNING: client-map stat error: %v", err)
			continue
		}
		mod := info.ModTime()
		if mod.Equal(lastMod) {
			continue
		}
		entries, err := loadClientMap(path)
		if err != nil {
			if !mod.Equal(failedMod) {
				log.Printf("WARNING: client-map reload failed, keeping previous data: %v", err)
				failedMod = mod
			}
			continue
		}
		cm.set(entries)
		lastMod = mod
		failedMod = time.Time{}
		log.Printf("client-map reloaded from %s (%d prefixes)", path, len(entries))
	}
}
// mergedVars returns the template variables for ip, built from every entry
// whose network contains it.
//
// Matches are applied from least to most specific (ascending prefix length),
// so a more specific prefix overrides keys set by a broader one. Matches with
// equal prefix length are applied in their order in entries, so the later
// entry wins. For entries built with net.ParseCIDR, equal length only happens
// when two entries describe the same network. The result is always a fresh,
// non-nil map that the caller may modify.
func mergedVars(entries []prefixEntry, ip net.IP) map[string]interface{} {
	var matches []prefixEntry
	for _, e := range entries {
		if e.network.Contains(ip) {
			matches = append(matches, e)
		}
	}
	// Stable: keeps the documented tie-break instead of leaving it to the
	// unspecified element order of an unstable sort.
	sort.SliceStable(matches, func(i, j int) bool {
		return matches[i].ones < matches[j].ones
	})
	result := map[string]interface{}{}
	for _, m := range matches {
		for k, v := range m.vars {
			result[k] = v
		}
	}
	return result
}
// extractTemplateFields returns the top-level data fields referenced by the
// main tree of tmpl. Results are in first-occurrence order with no
// duplicates. A reference is the first identifier of .field / .field.sub, or
// of $.field, since $ names the top-level data at any depth.
//
// Inside the body of {{range}} and {{with}} the dot is rebound to the element
// or value. Plain .field references there are therefore not top-level and are
// skipped, but $.field references there are still collected. The {{else}}
// branch of range/with runs with the original dot and is walked normally.
//
// Only tmpl's own tree is inspected. Bodies of {{define}}/{{block}} templates
// are separate trees and are not followed, because their dot depends on the
// argument passed to {{template}}.
func extractTemplateFields(tmpl *template.Template) []string {
	if tmpl.Tree == nil || tmpl.Tree.Root == nil {
		return nil
	}
	seen := map[string]bool{}
	var fields []string
	walkNode(tmpl.Tree.Root, seen, &fields)
	return fields
}

// walkNode appends to *fields every top-level field referenced under node that
// is not yet in seen, assuming dot is the top-level data at node.
func walkNode(node parse.Node, seen map[string]bool, fields *[]string) {
	walkNodeScoped(node, seen, fields, true)
}

// walkNodeScoped is walkNode with an explicit scope: when dotIsRoot is false
// (inside a range/with body) only $.field references are collected.
//
// Optional children such as a missing {{else}} list or the pipeline of a
// bare {{template "x"}} arrive as typed nil pointers wrapped in a non-nil
// parse.Node. Each pointer case below therefore checks for nil itself.
func walkNodeScoped(node parse.Node, seen map[string]bool, fields *[]string, dotIsRoot bool) {
	add := func(name string) {
		if !seen[name] {
			seen[name] = true
			*fields = append(*fields, name)
		}
	}
	switch n := node.(type) {
	case *parse.ListNode:
		if n == nil {
			return
		}
		for _, child := range n.Nodes {
			walkNodeScoped(child, seen, fields, dotIsRoot)
		}
	case *parse.ActionNode:
		if n == nil {
			return
		}
		walkNodeScoped(n.Pipe, seen, fields, dotIsRoot)
	case *parse.TemplateNode:
		if n == nil {
			return
		}
		walkNodeScoped(n.Pipe, seen, fields, dotIsRoot)
	case *parse.PipeNode:
		if n == nil {
			return
		}
		for _, cmd := range n.Cmds {
			for _, arg := range cmd.Args {
				walkNodeScoped(arg, seen, fields, dotIsRoot)
			}
		}
	case *parse.ChainNode:
		// e.g. (.a).b: the field chain hangs off an arbitrary operand.
		if n == nil {
			return
		}
		walkNodeScoped(n.Node, seen, fields, dotIsRoot)
	case *parse.FieldNode:
		if n != nil && dotIsRoot && len(n.Ident) > 0 {
			add(n.Ident[0])
		}
	case *parse.VariableNode:
		// $ is always the data passed to Execute, whatever the current dot.
		if n != nil && len(n.Ident) > 1 && n.Ident[0] == "$" {
			add(n.Ident[1])
		}
	case *parse.IfNode:
		if n == nil {
			return
		}
		walkNodeScoped(n.Pipe, seen, fields, dotIsRoot)
		walkNodeScoped(n.List, seen, fields, dotIsRoot)
		walkNodeScoped(n.ElseList, seen, fields, dotIsRoot)
	case *parse.RangeNode:
		if n == nil {
			return
		}
		walkNodeScoped(n.Pipe, seen, fields, dotIsRoot)
		walkNodeScoped(n.List, seen, fields, false) // dot = current element
		walkNodeScoped(n.ElseList, seen, fields, dotIsRoot)
	case *parse.WithNode:
		if n == nil {
			return
		}
		walkNodeScoped(n.Pipe, seen, fields, dotIsRoot)
		walkNodeScoped(n.List, seen, fields, false) // dot = pipeline value
		walkNodeScoped(n.ElseList, seen, fields, dotIsRoot)
	}
}
// loggingResponseWriter wraps an http.ResponseWriter and records the response
// status and body size for the access log.
//
// statusCode ends up holding the status actually sent. That is either the
// first final (non-1xx) WriteHeader call, or 200 when the body is written
// first, because net/http then sends an implicit 200 and ignores later
// WriteHeader calls. A caller may pre-set statusCode as a default; the first
// committed status overwrites it.
type loggingResponseWriter struct {
	http.ResponseWriter
	statusCode   int
	bytesWritten int
	wroteHeader  bool // final status committed; later WriteHeader calls are superfluous
}

// WriteHeader records code if it is the first final status and always
// forwards the call, so the wrapped writer keeps its own handling of
// duplicate or informational calls.
func (lrw *loggingResponseWriter) WriteHeader(code int) {
	// 1xx codes other than 101 Switching Protocols are informational and do
	// not commit the final status.
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !lrw.wroteHeader && !informational {
		lrw.statusCode = code
		lrw.wroteHeader = true
	}
	lrw.ResponseWriter.WriteHeader(code)
}

// Write forwards b to the wrapped writer and adds the bytes it accepted to
// bytesWritten. A first Write without a prior WriteHeader commits status 200.
func (lrw *loggingResponseWriter) Write(b []byte) (int, error) {
	if !lrw.wroteHeader {
		lrw.statusCode = http.StatusOK
		lrw.wroteHeader = true
	}
	n, err := lrw.ResponseWriter.Write(b)
	lrw.bytesWritten += n
	return n, err
}

// Unwrap returns the wrapped writer so http.ResponseController can reach
// optional capabilities (flushing, hijacking, deadlines) behind this wrapper.
func (lrw *loggingResponseWriter) Unwrap() http.ResponseWriter {
	return lrw.ResponseWriter
}
// main does four things:
//   - parses flags, each of which falls back to an environment variable;
//   - loads the client map;
//   - starts the background watcher that hot-reloads the map;
//   - serves rendered docroot files over HTTP.
//
// One access-log line in Common Log Format style is written to stdout per
// request.
func main() {
	clientMapPath := flag.String("client-map", getEnv("CLIENT_MAP", "config/client-map.yaml"), "path to client-map YAML file [$CLIENT_MAP]")
	listenAddr := flag.String("listen", getEnv("LISTEN", ":80"), "listen address [$LISTEN]")
	docrootDir := flag.String("docroot", getEnv("DOCROOT", "docroot"), "path to docroot directory [$DOCROOT]")
	flag.Parse()

	// Stat BEFORE loading. If the file changes between the two calls, the
	// recorded mtime is older than the new content, and the watcher reloads
	// on its next poll. The reverse order could record the new mtime while
	// holding the old content, and that update would never be loaded.
	info, err := os.Stat(*clientMapPath)
	if err != nil {
		log.Fatalf("stat client-map: %v", err)
	}
	entries, err := loadClientMap(*clientMapPath)
	if err != nil {
		log.Fatalf("loading client-map: %v", err)
	}
	cm := &clientMap{}
	cm.set(entries)
	go watchClientMap(*clientMapPath, cm, info.ModTime())

	absDocroot, err := filepath.Abs(*docrootDir)
	if err != nil {
		log.Fatalf("resolving docroot: %v", err)
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		// Common Log Format records when the request was received, not when
		// the response finished.
		received := time.Now()
		lrw := &loggingResponseWriter{ResponseWriter: w, statusCode: 200}
		serveRequest(lrw, r, cm.get(), absDocroot)
		fmt.Printf("%s - - [%s] \"%s %s %s\" %d %d \"-\" \"%s\"\n",
			clientIP(r),
			received.Format("02/Jan/2006:15:04:05 -0700"),
			r.Method, r.URL.RequestURI(), r.Proto,
			lrw.statusCode, lrw.bytesWritten,
			r.Header.Get("User-Agent"),
		)
	})

	// Explicit timeouts: http.ListenAndServe uses a zero-valued Server that
	// never times out, letting slow or idle clients hold connections forever.
	srv := &http.Server{
		Addr:              *listenAddr,
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       120 * time.Second,
	}
	log.Printf("listening on %s, docroot=%s", *listenAddr, absDocroot)
	log.Fatal(srv.ListenAndServe())
}
// getEnv returns the value of the environment variable key, or fallback when
// the variable is unset or set to the empty string.
func getEnv(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
// clientIP returns the host portion of r.RemoteAddr: brackets are removed
// from IPv6 literals, and any zone suffix is kept. When RemoteAddr is not in
// host:port form it is returned unchanged.
func clientIP(r *http.Request) string {
	addr := r.RemoteAddr
	if host, _, err := net.SplitHostPort(addr); err == nil {
		return host
	}
	return addr
}
// serveRequest renders the docroot file addressed by r.URL.Path as a
// text/template and writes the result as text/plain.
//
// Template data comes from every client-map entry whose prefix contains the
// client IP (merged by mergedVars). Request-derived values remote_address,
// host and user_agent are then set on top and always win. Template fields that
// are referenced but have no value for this client are logged and rendered as
// "".
//
// Path handling:
//   - "/" maps to /index.txt, and a path naming a directory maps to the
//     index.txt inside it.
//   - A path resolving outside absDocroot gets 403.
//   - A missing file gets 404.
//   - Read, parse and execute failures get 500 and are logged.
//   - A client address that is not an IP gets 400.
//
// absDocroot is expected to be absolute and cleaned, as filepath.Abs returns
// it.
func serveRequest(w http.ResponseWriter, r *http.Request, entries []prefixEntry, absDocroot string) {
	host := clientIP(r)

	// Drop an IPv6 zone ("fe80::1%eth0", seen on link-local peers) before
	// parsing. net.ParseIP rejects zoned addresses, which would turn every
	// link-local client into a 400.
	ipText := host
	if i := strings.IndexByte(ipText, '%'); i >= 0 {
		ipText = ipText[:i]
	}
	ip := net.ParseIP(ipText)
	if ip == nil {
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}
	// Normalize IPv4-mapped IPv6 addresses (dual-stack sockets) to 4-byte form.
	if v4 := ip.To4(); v4 != nil {
		ip = v4
	}

	// Template data: merged prefix matches first, then HTTP-derived values,
	// which overwrite any client-map variable of the same name.
	data := mergedVars(entries, ip)
	data["remote_address"] = host
	data["host"] = r.Host
	data["user_agent"] = r.Header.Get("User-Agent")

	// Resolve and validate the file path.
	urlPath := r.URL.Path
	if urlPath == "/" || urlPath == "" {
		urlPath = "/index.txt"
	}
	relPath := strings.TrimPrefix(filepath.Clean(urlPath), "/")
	absPath, err := filepath.Abs(filepath.Join(absDocroot, relPath))
	// Containment check with a trailing separator, so "/srv/doc" does not
	// admit "/srv/docroot2". The separator is not doubled when absDocroot is
	// the filesystem root, which would otherwise reject every path.
	sep := string(os.PathSeparator)
	rootPrefix := absDocroot
	if !strings.HasSuffix(rootPrefix, sep) {
		rootPrefix += sep
	}
	if err != nil || !strings.HasPrefix(absPath+sep, rootPrefix) {
		http.Error(w, "Forbidden", http.StatusForbidden)
		return
	}
	// A directory is served through its index.txt, mirroring the "/" rule
	// above, instead of failing the read below with a 500.
	if fi, statErr := os.Stat(absPath); statErr == nil && fi.IsDir() {
		absPath = filepath.Join(absPath, "index.txt")
	}

	content, err := os.ReadFile(absPath)
	if err != nil {
		if os.IsNotExist(err) {
			http.Error(w, "Not Found", http.StatusNotFound)
			return
		}
		log.Printf("ERROR: read error for %s: %v", absPath, err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	tmpl, err := template.New("").Parse(string(content))
	if err != nil {
		log.Printf("ERROR: template parse error for %s: %v", absPath, err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}
	// Warn for each template variable with no value for this client. Default
	// it to "" so the output does not contain "<no value>".
	for _, field := range extractTemplateFields(tmpl) {
		if _, ok := data[field]; !ok {
			log.Printf("WARNING: template variable .%s has no value for client %s", field, host)
			data[field] = ""
		}
	}
	// Render into a buffer first so an execution error can still become a
	// clean 500 instead of a truncated 200.
	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, data); err != nil {
		log.Printf("ERROR: template execute error for %s: %v", absPath, err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
	// A write error means the client went away; there is nobody left to tell.
	_, _ = w.Write(buf.Bytes())
}