package main import ( "bytes" "flag" "fmt" "log" "net" "net/http" "os" "path/filepath" "sort" "strings" "sync" "text/template" "text/template/parse" "time" "gopkg.in/yaml.v3" ) type prefixEntry struct { network *net.IPNet ones int vars map[string]string } func loadClientMap(path string) ([]prefixEntry, error) { data, err := os.ReadFile(path) if err != nil { return nil, err } raw := map[string]map[string]string{} if err := yaml.Unmarshal(data, &raw); err != nil { return nil, err } var entries []prefixEntry for cidr, vars := range raw { _, network, err := net.ParseCIDR(cidr) if err != nil { return nil, fmt.Errorf("invalid CIDR %q: %w", cidr, err) } ones, _ := network.Mask.Size() entries = append(entries, prefixEntry{network: network, ones: ones, vars: vars}) } return entries, nil } type clientMap struct { mu sync.RWMutex entries []prefixEntry } func (cm *clientMap) get() []prefixEntry { cm.mu.RLock() defer cm.mu.RUnlock() return cm.entries } func (cm *clientMap) set(entries []prefixEntry) { cm.mu.Lock() defer cm.mu.Unlock() cm.entries = entries } // watchClientMap polls path every 5 seconds and reloads when mtime changes. // On parse/CIDR error it logs a warning and keeps the previous data. func watchClientMap(path string, cm *clientMap, initialMod time.Time) { lastMod := initialMod for { time.Sleep(5 * time.Second) info, err := os.Stat(path) if err != nil { log.Printf("WARNING: client-map stat error: %v", err) continue } if !info.ModTime().After(lastMod) { continue } entries, err := loadClientMap(path) if err != nil { log.Printf("WARNING: client-map reload failed, keeping previous data: %v", err) continue } cm.set(entries) lastMod = info.ModTime() log.Printf("client-map reloaded from %s (%d prefixes)", path, len(entries)) } } // mergedVars collects all matching prefix entries, sorts least-to-most specific, // and merges their var maps so more-specific entries overwrite less-specific ones. func mergedVars(entries []prefixEntry, ip net.IP) map[string]interface{} { type match struct { ones int vars map[string]string } var matches []match for _, e := range entries { if e.network.Contains(ip) { matches = append(matches, match{e.ones, e.vars}) } } sort.Slice(matches, func(i, j int) bool { return matches[i].ones < matches[j].ones }) result := map[string]interface{}{} for _, m := range matches { for k, v := range m.vars { result[k] = v } } return result } // extractTemplateFields walks the parsed template AST and returns all top-level // field names referenced as {{ .fieldname }}. func extractTemplateFields(tmpl *template.Template) []string { if tmpl.Tree == nil || tmpl.Tree.Root == nil { return nil } seen := map[string]bool{} var fields []string walkNode(tmpl.Tree.Root, seen, &fields) return fields } func walkNode(node parse.Node, seen map[string]bool, fields *[]string) { if node == nil { return } switch n := node.(type) { case *parse.ListNode: for _, child := range n.Nodes { walkNode(child, seen, fields) } case *parse.ActionNode: walkNode(n.Pipe, seen, fields) case *parse.PipeNode: for _, cmd := range n.Cmds { for _, arg := range cmd.Args { walkNode(arg, seen, fields) } } case *parse.FieldNode: if len(n.Ident) > 0 && !seen[n.Ident[0]] { seen[n.Ident[0]] = true *fields = append(*fields, n.Ident[0]) } case *parse.IfNode: walkNode(n.Pipe, seen, fields) walkNode(n.List, seen, fields) walkNode(n.ElseList, seen, fields) case *parse.RangeNode: walkNode(n.Pipe, seen, fields) walkNode(n.List, seen, fields) walkNode(n.ElseList, seen, fields) case *parse.WithNode: walkNode(n.Pipe, seen, fields) walkNode(n.List, seen, fields) walkNode(n.ElseList, seen, fields) } } type loggingResponseWriter struct { http.ResponseWriter statusCode int bytesWritten int } func (lrw *loggingResponseWriter) WriteHeader(code int) { lrw.statusCode = code lrw.ResponseWriter.WriteHeader(code) } func (lrw *loggingResponseWriter) Write(b []byte) (int, error) { n, err := lrw.ResponseWriter.Write(b) lrw.bytesWritten += n return n, err } func main() { clientMapPath := flag.String("client-map", getEnv("CLIENT_MAP", "client-map.yaml"), "path to client-map YAML file [$CLIENT_MAP]") listenAddr := flag.String("listen", getEnv("LISTEN", ":80"), "listen address [$LISTEN]") docrootDir := flag.String("docroot", getEnv("DOCROOT", "docroot"), "path to docroot directory [$DOCROOT]") flag.Parse() entries, err := loadClientMap(*clientMapPath) if err != nil { log.Fatalf("loading client-map: %v", err) } info, err := os.Stat(*clientMapPath) if err != nil { log.Fatalf("stat client-map: %v", err) } cm := &clientMap{} cm.set(entries) go watchClientMap(*clientMapPath, cm, info.ModTime()) absDocroot, err := filepath.Abs(*docrootDir) if err != nil { log.Fatalf("resolving docroot: %v", err) } mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { lrw := &loggingResponseWriter{ResponseWriter: w, statusCode: 200} serveRequest(lrw, r, cm.get(), absDocroot) fmt.Printf("%s - - [%s] \"%s %s %s\" %d %d \"-\" \"%s\"\n", clientIP(r), time.Now().Format("02/Jan/2006:15:04:05 -0700"), r.Method, r.URL.RequestURI(), r.Proto, lrw.statusCode, lrw.bytesWritten, r.Header.Get("User-Agent"), ) }) log.Printf("listening on %s, docroot=%s", *listenAddr, absDocroot) log.Fatal(http.ListenAndServe(*listenAddr, mux)) } func getEnv(key, fallback string) string { if v := os.Getenv(key); v != "" { return v } return fallback } func clientIP(r *http.Request) string { host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return r.RemoteAddr } return host } func serveRequest(w http.ResponseWriter, r *http.Request, entries []prefixEntry, absDocroot string) { // Parse and normalize client IP (handle IPv4-mapped IPv6 from dual-stack sockets) host := clientIP(r) ip := net.ParseIP(host) if ip == nil { http.Error(w, "Bad Request", http.StatusBadRequest) return } if v4 := ip.To4(); v4 != nil { ip = v4 } // Build template data from merged prefix matches; HTTP-derived vars always win data := mergedVars(entries, ip) data["remote_address"] = host data["host"] = r.Host data["user_agent"] = r.Header.Get("User-Agent") // Resolve and validate file path urlPath := r.URL.Path if urlPath == "/" || urlPath == "" { urlPath = "/index.txt" } relPath := strings.TrimPrefix(filepath.Clean(urlPath), "/") absPath, err := filepath.Abs(filepath.Join(absDocroot, relPath)) if err != nil || !strings.HasPrefix(absPath+string(os.PathSeparator), absDocroot+string(os.PathSeparator)) { http.Error(w, "Forbidden", http.StatusForbidden) return } content, err := os.ReadFile(absPath) if err != nil { if os.IsNotExist(err) { http.Error(w, "Not Found", http.StatusNotFound) } else { http.Error(w, "Internal Server Error", http.StatusInternalServerError) } return } tmpl, err := template.New("").Parse(string(content)) if err != nil { log.Printf("ERROR: template parse error for %s: %v", absPath, err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } // Warn for each template variable with no value for this client; default to "" for _, field := range extractTemplateFields(tmpl) { if _, ok := data[field]; !ok { log.Printf("WARNING: template variable .%s has no value for client %s", field, host) data[field] = "" } } var buf bytes.Buffer if err := tmpl.Execute(&buf, data); err != nil { log.Printf("ERROR: template execute error for %s: %v", absPath, err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Write(buf.Bytes()) }