From d8ad89d11543b4ec1793bc47585b1cbc94c4a179 Mon Sep 17 00:00:00 2001 From: Pim van Pelt Date: Sat, 11 Apr 2026 02:18:44 +0200 Subject: [PATCH] When the server exits (^C or because docker/systemd exits it), streaming gRPC clients must be closed. Currently, the server does not exit until the gRPC client disconnects. --- cmd/maglevd/main.go | 2 +- internal/grpcapi/server.go | 11 +++++-- internal/grpcapi/server_test.go | 54 +++++++++++++++++++++++++-------- 3 files changed, 50 insertions(+), 17 deletions(-) diff --git a/cmd/maglevd/main.go b/cmd/maglevd/main.go index d393fcd..f9caa6e 100644 --- a/cmd/maglevd/main.go +++ b/cmd/maglevd/main.go @@ -65,7 +65,7 @@ func run() error { return fmt.Errorf("listen %s: %w", *grpcAddr, err) } srv := grpc.NewServer() - grpcapi.RegisterMaglevServer(srv, grpcapi.NewServer(chkr)) + grpcapi.RegisterMaglevServer(srv, grpcapi.NewServer(ctx, chkr)) slog.Info("grpc-listening", "addr", *grpcAddr) go func() { diff --git a/internal/grpcapi/server.go b/internal/grpcapi/server.go index dc08832..ab55b69 100644 --- a/internal/grpcapi/server.go +++ b/internal/grpcapi/server.go @@ -17,12 +17,15 @@ import ( // Server implements the MaglevServer gRPC interface. type Server struct { UnimplementedMaglevServer + ctx context.Context checker *checker.Checker } -// NewServer creates a Server backed by the given Checker. -func NewServer(c *checker.Checker) *Server { - return &Server{checker: c} +// NewServer creates a Server backed by the given Checker. The provided context +// controls the lifetime of streaming RPCs: cancelling it closes all active +// WatchBackendEvents streams so that grpc.Server.GracefulStop can complete. +func NewServer(ctx context.Context, c *checker.Checker) *Server { + return &Server{ctx: ctx, checker: c} } // ListFrontends returns the names of all configured frontends. @@ -112,6 +115,8 @@ func (s *Server) WatchBackendEvents(_ *WatchRequest, stream Maglev_WatchBackendE for { select { + case <-s.ctx.Done(): + return status.Error(codes.Unavailable, "server shutting down") case <-stream.Context().Done(): return nil case e, ok := <-ch: diff --git a/internal/grpcapi/server_test.go b/internal/grpcapi/server_test.go index 97a5b06..f239919 100644 --- a/internal/grpcapi/server_test.go +++ b/internal/grpcapi/server_test.go @@ -52,14 +52,14 @@ func makeTestChecker(ctx context.Context) *checker.Checker { return c } -func startTestServer(t *testing.T, c *checker.Checker) (MaglevClient, func()) { +func startTestServer(t *testing.T, ctx context.Context, c *checker.Checker) (MaglevClient, func()) { t.Helper() lis, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen: %v", err) } srv := grpc.NewServer() - RegisterMaglevServer(srv, NewServer(c)) + RegisterMaglevServer(srv, NewServer(ctx, c)) go srv.Serve(lis) //nolint:errcheck conn, err := grpc.NewClient(lis.Addr().String(), @@ -78,7 +78,7 @@ func TestListFrontends(t *testing.T) { defer cancel() c := makeTestChecker(ctx) - client, cleanup := startTestServer(t, c) + client, cleanup := startTestServer(t, ctx, c) defer cleanup() resp, err := client.ListFrontends(ctx, &ListFrontendsRequest{}) @@ -95,7 +95,7 @@ func TestGetFrontend(t *testing.T) { defer cancel() c := makeTestChecker(ctx) - client, cleanup := startTestServer(t, c) + client, cleanup := startTestServer(t, ctx, c) defer cleanup() info, err := client.GetFrontend(ctx, &GetFrontendRequest{Name:"web"}) @@ -118,7 +118,7 @@ func TestGetFrontendNotFound(t *testing.T) { defer cancel() c := makeTestChecker(ctx) - client, cleanup := startTestServer(t, c) + client, cleanup := startTestServer(t, ctx, c) defer cleanup() _, err := client.GetFrontend(ctx, &GetFrontendRequest{Name:"nope"}) @@ -132,7 +132,7 @@ func TestListBackends(t *testing.T) { defer cancel() c := makeTestChecker(ctx) - client, cleanup := startTestServer(t, c) + client, cleanup := startTestServer(t, ctx, c) defer cleanup() resp, err := client.ListBackends(ctx, &ListBackendsRequest{}) @@ -149,7 +149,7 @@ func TestGetBackend(t *testing.T) { defer cancel() c := makeTestChecker(ctx) - client, cleanup := startTestServer(t, c) + client, cleanup := startTestServer(t, ctx, c) defer cleanup() info, err := client.GetBackend(ctx, &GetBackendRequest{Name:"be0"}) @@ -175,7 +175,7 @@ func TestGetBackendNotFound(t *testing.T) { defer cancel() c := makeTestChecker(ctx) - client, cleanup := startTestServer(t, c) + client, cleanup := startTestServer(t, ctx, c) defer cleanup() _, err := client.GetBackend(ctx, &GetBackendRequest{Name:"nope"}) @@ -189,7 +189,7 @@ func TestPauseResumeBackend(t *testing.T) { defer cancel() c := makeTestChecker(ctx) - client, cleanup := startTestServer(t, c) + client, cleanup := startTestServer(t, ctx, c) defer cleanup() info, err := client.PauseBackend(ctx, &PauseResumeRequest{Name:"be0"}) @@ -214,7 +214,7 @@ func TestListHealthChecks(t *testing.T) { defer cancel() c := makeTestChecker(ctx) - client, cleanup := startTestServer(t, c) + client, cleanup := startTestServer(t, ctx, c) defer cleanup() resp, err := client.ListHealthChecks(ctx, &ListHealthChecksRequest{}) @@ -231,7 +231,7 @@ func TestGetHealthCheck(t *testing.T) { defer cancel() c := makeTestChecker(ctx) - client, cleanup := startTestServer(t, c) + client, cleanup := startTestServer(t, ctx, c) defer cleanup() info, err := client.GetHealthCheck(ctx, &GetHealthCheckRequest{Name: "icmp"}) @@ -251,7 +251,7 @@ func TestGetHealthCheckNotFound(t *testing.T) { defer cancel() c := makeTestChecker(ctx) - client, cleanup := startTestServer(t, c) + client, cleanup := startTestServer(t, ctx, c) defer cleanup() _, err := client.GetHealthCheck(ctx, &GetHealthCheckRequest{Name: "nope"}) @@ -260,12 +260,40 @@ func TestGetHealthCheckNotFound(t *testing.T) { } } +func TestWatchBackendEventsServerShutdown(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c := makeTestChecker(ctx) + + // Use a separate server context so we can cancel it independently. + srvCtx, srvCancel := context.WithCancel(ctx) + client, cleanup := startTestServer(t, srvCtx, c) + defer cleanup() + + stream, err := client.WatchBackendEvents(ctx, &WatchRequest{}) + if err != nil { + t.Fatalf("WatchBackendEvents: %v", err) + } + // Drain the initial synthetic event. + if _, err := stream.Recv(); err != nil { + t.Fatalf("initial Recv: %v", err) + } + + // Cancel the server context; the stream must terminate. + srvCancel() + _, err = stream.Recv() + if err == nil { + t.Fatal("expected stream to close after server shutdown, got nil error") + } +} + func TestWatchBackendEvents(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() c := makeTestChecker(ctx) - client, cleanup := startTestServer(t, c) + client, cleanup := startTestServer(t, ctx, c) defer cleanup() stream, err := client.WatchBackendEvents(ctx, &WatchRequest{})