A toy webserver with variable expansion based on connecting client
This commit is contained in:
301
main.go
Normal file
301
main.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"text/template"
|
||||
"text/template/parse"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type prefixEntry struct {
|
||||
network *net.IPNet
|
||||
ones int
|
||||
vars map[string]string
|
||||
}
|
||||
|
||||
func loadClientMap(path string) ([]prefixEntry, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw := map[string]map[string]string{}
|
||||
if err := yaml.Unmarshal(data, &raw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var entries []prefixEntry
|
||||
for cidr, vars := range raw {
|
||||
_, network, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid CIDR %q: %w", cidr, err)
|
||||
}
|
||||
ones, _ := network.Mask.Size()
|
||||
entries = append(entries, prefixEntry{network: network, ones: ones, vars: vars})
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
type clientMap struct {
|
||||
mu sync.RWMutex
|
||||
entries []prefixEntry
|
||||
}
|
||||
|
||||
func (cm *clientMap) get() []prefixEntry {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return cm.entries
|
||||
}
|
||||
|
||||
func (cm *clientMap) set(entries []prefixEntry) {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
cm.entries = entries
|
||||
}
|
||||
|
||||
// watchClientMap polls path every 5 seconds and reloads when mtime changes.
|
||||
// On parse/CIDR error it logs a warning and keeps the previous data.
|
||||
func watchClientMap(path string, cm *clientMap, initialMod time.Time) {
|
||||
lastMod := initialMod
|
||||
for {
|
||||
time.Sleep(5 * time.Second)
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
log.Printf("WARNING: client-map stat error: %v", err)
|
||||
continue
|
||||
}
|
||||
if !info.ModTime().After(lastMod) {
|
||||
continue
|
||||
}
|
||||
entries, err := loadClientMap(path)
|
||||
if err != nil {
|
||||
log.Printf("WARNING: client-map reload failed, keeping previous data: %v", err)
|
||||
continue
|
||||
}
|
||||
cm.set(entries)
|
||||
lastMod = info.ModTime()
|
||||
log.Printf("client-map reloaded from %s (%d prefixes)", path, len(entries))
|
||||
}
|
||||
}
|
||||
|
||||
// mergedVars collects all matching prefix entries, sorts least-to-most specific,
|
||||
// and merges their var maps so more-specific entries overwrite less-specific ones.
|
||||
func mergedVars(entries []prefixEntry, ip net.IP) map[string]interface{} {
|
||||
type match struct {
|
||||
ones int
|
||||
vars map[string]string
|
||||
}
|
||||
var matches []match
|
||||
for _, e := range entries {
|
||||
if e.network.Contains(ip) {
|
||||
matches = append(matches, match{e.ones, e.vars})
|
||||
}
|
||||
}
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
return matches[i].ones < matches[j].ones
|
||||
})
|
||||
result := map[string]interface{}{}
|
||||
for _, m := range matches {
|
||||
for k, v := range m.vars {
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// extractTemplateFields walks the parsed template AST and returns all top-level
|
||||
// field names referenced as {{ .fieldname }}.
|
||||
func extractTemplateFields(tmpl *template.Template) []string {
|
||||
if tmpl.Tree == nil || tmpl.Tree.Root == nil {
|
||||
return nil
|
||||
}
|
||||
seen := map[string]bool{}
|
||||
var fields []string
|
||||
walkNode(tmpl.Tree.Root, seen, &fields)
|
||||
return fields
|
||||
}
|
||||
|
||||
func walkNode(node parse.Node, seen map[string]bool, fields *[]string) {
|
||||
if node == nil {
|
||||
return
|
||||
}
|
||||
switch n := node.(type) {
|
||||
case *parse.ListNode:
|
||||
for _, child := range n.Nodes {
|
||||
walkNode(child, seen, fields)
|
||||
}
|
||||
case *parse.ActionNode:
|
||||
walkNode(n.Pipe, seen, fields)
|
||||
case *parse.PipeNode:
|
||||
for _, cmd := range n.Cmds {
|
||||
for _, arg := range cmd.Args {
|
||||
walkNode(arg, seen, fields)
|
||||
}
|
||||
}
|
||||
case *parse.FieldNode:
|
||||
if len(n.Ident) > 0 && !seen[n.Ident[0]] {
|
||||
seen[n.Ident[0]] = true
|
||||
*fields = append(*fields, n.Ident[0])
|
||||
}
|
||||
case *parse.IfNode:
|
||||
walkNode(n.Pipe, seen, fields)
|
||||
walkNode(n.List, seen, fields)
|
||||
walkNode(n.ElseList, seen, fields)
|
||||
case *parse.RangeNode:
|
||||
walkNode(n.Pipe, seen, fields)
|
||||
walkNode(n.List, seen, fields)
|
||||
walkNode(n.ElseList, seen, fields)
|
||||
case *parse.WithNode:
|
||||
walkNode(n.Pipe, seen, fields)
|
||||
walkNode(n.List, seen, fields)
|
||||
walkNode(n.ElseList, seen, fields)
|
||||
}
|
||||
}
|
||||
|
||||
type loggingResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
bytesWritten int
|
||||
}
|
||||
|
||||
func (lrw *loggingResponseWriter) WriteHeader(code int) {
|
||||
lrw.statusCode = code
|
||||
lrw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (lrw *loggingResponseWriter) Write(b []byte) (int, error) {
|
||||
n, err := lrw.ResponseWriter.Write(b)
|
||||
lrw.bytesWritten += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
func main() {
|
||||
clientMapPath := flag.String("client-map", getEnv("CLIENT_MAP", "client-map.yaml"), "path to client-map YAML file [$CLIENT_MAP]")
|
||||
listenAddr := flag.String("listen", getEnv("LISTEN", ":80"), "listen address [$LISTEN]")
|
||||
docrootDir := flag.String("docroot", getEnv("DOCROOT", "docroot"), "path to docroot directory [$DOCROOT]")
|
||||
flag.Parse()
|
||||
|
||||
entries, err := loadClientMap(*clientMapPath)
|
||||
if err != nil {
|
||||
log.Fatalf("loading client-map: %v", err)
|
||||
}
|
||||
info, err := os.Stat(*clientMapPath)
|
||||
if err != nil {
|
||||
log.Fatalf("stat client-map: %v", err)
|
||||
}
|
||||
cm := &clientMap{}
|
||||
cm.set(entries)
|
||||
go watchClientMap(*clientMapPath, cm, info.ModTime())
|
||||
|
||||
absDocroot, err := filepath.Abs(*docrootDir)
|
||||
if err != nil {
|
||||
log.Fatalf("resolving docroot: %v", err)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
lrw := &loggingResponseWriter{ResponseWriter: w, statusCode: 200}
|
||||
serveRequest(lrw, r, cm.get(), absDocroot)
|
||||
fmt.Printf("%s - - [%s] \"%s %s %s\" %d %d \"-\" \"%s\"\n",
|
||||
clientIP(r),
|
||||
time.Now().Format("02/Jan/2006:15:04:05 -0700"),
|
||||
r.Method, r.URL.RequestURI(), r.Proto,
|
||||
lrw.statusCode, lrw.bytesWritten,
|
||||
r.Header.Get("User-Agent"),
|
||||
)
|
||||
})
|
||||
|
||||
log.Printf("listening on %s, docroot=%s", *listenAddr, absDocroot)
|
||||
log.Fatal(http.ListenAndServe(*listenAddr, mux))
|
||||
}
|
||||
|
||||
func getEnv(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func clientIP(r *http.Request) string {
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
func serveRequest(w http.ResponseWriter, r *http.Request, entries []prefixEntry, absDocroot string) {
|
||||
// Parse and normalize client IP (handle IPv4-mapped IPv6 from dual-stack sockets)
|
||||
host := clientIP(r)
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
http.Error(w, "Bad Request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if v4 := ip.To4(); v4 != nil {
|
||||
ip = v4
|
||||
}
|
||||
|
||||
// Build template data from merged prefix matches; HTTP-derived vars always win
|
||||
data := mergedVars(entries, ip)
|
||||
data["remote_address"] = host
|
||||
data["host"] = r.Host
|
||||
data["user_agent"] = r.Header.Get("User-Agent")
|
||||
|
||||
// Resolve and validate file path
|
||||
urlPath := r.URL.Path
|
||||
if urlPath == "/" || urlPath == "" {
|
||||
urlPath = "/index.txt"
|
||||
}
|
||||
relPath := strings.TrimPrefix(filepath.Clean(urlPath), "/")
|
||||
absPath, err := filepath.Abs(filepath.Join(absDocroot, relPath))
|
||||
if err != nil || !strings.HasPrefix(absPath+string(os.PathSeparator), absDocroot+string(os.PathSeparator)) {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
http.Error(w, "Not Found", http.StatusNotFound)
|
||||
} else {
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
tmpl, err := template.New("").Parse(string(content))
|
||||
if err != nil {
|
||||
log.Printf("ERROR: template parse error for %s: %v", absPath, err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Warn for each template variable with no value for this client; default to ""
|
||||
for _, field := range extractTemplateFields(tmpl) {
|
||||
if _, ok := data[field]; !ok {
|
||||
log.Printf("WARNING: template variable .%s has no value for client %s", field, host)
|
||||
data[field] = ""
|
||||
}
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := tmpl.Execute(&buf, data); err != nil {
|
||||
log.Printf("ERROR: template execute error for %s: %v", absPath, err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Write(buf.Bytes())
|
||||
}
|
||||
Reference in New Issue
Block a user