From 1a576b16c14038d333089e1030c13e66dfdc7164 Mon Sep 17 00:00:00 2001 From: Epid Date: Tue, 24 Mar 2026 04:44:00 +0300 Subject: [PATCH] fix(websocket): address silverwind review feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Move /-/ws route inside reqSignIn middleware group; remove manual ctx.IsSigned check from handler (auth is now enforced by the router) - Fix scheduleReconnect() to schedule using current delay then double, so first reconnect fires after 50ms not 100ms (reported by silverwind) - Replace sourcesByPort.set(port, null) with delete() to prevent MessagePort retention after tab close (memory leak fix) - Centralize topic naming in pubsub.UserTopic() — removes duplication between the notifier and the WebSocket handler - Skip DB polling in notifier when broker has no active subscribers to avoid unnecessary load on idle instances - Hold RLock for the full Publish fan-out loop to prevent a race where cancel() closes a channel between slice read and send --- routers/web/web.go | 4 ++- routers/web/websocket/websocket.go | 11 ++------ services/pubsub/broker.go | 26 ++++++++++++++++--- services/websocket/notifier.go | 12 ++++----- web_src/js/features/websocket.sharedworker.ts | 15 ++++++----- 5 files changed, 42 insertions(+), 26 deletions(-) diff --git a/routers/web/web.go b/routers/web/web.go index 6cf209a886..2658f4b40d 100644 --- a/routers/web/web.go +++ b/routers/web/web.go @@ -593,7 +593,9 @@ func registerWebRoutes(m *web.Router, webAuth *AuthMiddleware) { }, reqSignOut) m.Any("/user/events", routing.MarkLongPolling, events.Events) - m.Get("/-/ws", gitea_websocket.Serve) + m.Group("", func() { + m.Get("/-/ws", gitea_websocket.Serve) + }, reqSignIn) m.Group("/login/oauth", func() { m.Group("", func() { diff --git a/routers/web/websocket/websocket.go b/routers/web/websocket/websocket.go index 6feb81008d..cfa146e347 100644 --- a/routers/web/websocket/websocket.go +++ b/routers/web/websocket/websocket.go @@ -4,8 +4,6 @@ package websocket import ( - "fmt" - "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/services/context" "code.gitea.io/gitea/services/pubsub" @@ -14,12 +12,8 @@ import ( ) // Serve handles WebSocket upgrade and event delivery for the signed-in user. +// Authentication is enforced by the reqSignIn middleware in the router. func Serve(ctx *context.Context) { - if !ctx.IsSigned { - ctx.Status(401) - return - } - conn, err := gitea_ws.Accept(ctx.Resp, ctx.Req, &gitea_ws.AcceptOptions{ InsecureSkipVerify: false, }) @@ -29,8 +23,7 @@ func Serve(ctx *context.Context) { } defer conn.CloseNow() //nolint:errcheck // CloseNow is best-effort; error is intentionally ignored - topic := fmt.Sprintf("user-%d", ctx.Doer.ID) - ch, cancel := pubsub.DefaultBroker.Subscribe(topic) + ch, cancel := pubsub.DefaultBroker.Subscribe(pubsub.UserTopic(ctx.Doer.ID)) defer cancel() wsCtx := ctx.Req.Context() diff --git a/services/pubsub/broker.go b/services/pubsub/broker.go index c2f2dc1026..9143742489 100644 --- a/services/pubsub/broker.go +++ b/services/pubsub/broker.go @@ -4,6 +4,7 @@ package pubsub import ( + "fmt" "sync" ) @@ -48,14 +49,33 @@ func (b *Broker) Subscribe(topic string) (<-chan []byte, func()) { return ch, cancel } +// UserTopic returns the pub/sub topic name for a given user ID. +// Centralised here so the notifier and the WebSocket handler always agree on the format. +func UserTopic(userID int64) string { + return fmt.Sprintf("user-%d", userID) +} + +// HasSubscribers reports whether the broker has at least one active subscriber across all topics. +func (b *Broker) HasSubscribers() bool { + b.mu.RLock() + defer b.mu.RUnlock() + for _, subs := range b.subs { + if len(subs) > 0 { + return true + } + } + return false +} + // Publish sends msg to all subscribers of topic. // Non-blocking: slow subscribers are skipped. +// The RLock is held for the entire fan-out to prevent a race where cancel() +// closes a channel between the slice read and the send. func (b *Broker) Publish(topic string, msg []byte) { b.mu.RLock() - subs := b.subs[topic] - b.mu.RUnlock() + defer b.mu.RUnlock() - for _, ch := range subs { + for _, ch := range b.subs[topic] { select { case ch <- msg: default: diff --git a/services/websocket/notifier.go b/services/websocket/notifier.go index af00d75d03..d8f64ef7d5 100644 --- a/services/websocket/notifier.go +++ b/services/websocket/notifier.go @@ -5,7 +5,6 @@ package websocket import ( "context" - "fmt" "time" activities_model "code.gitea.io/gitea/models/activities" @@ -29,10 +28,6 @@ type notificationCountEvent struct { Count int64 `json:"count"` } -func userTopic(userID int64) string { - return fmt.Sprintf("user-%d", userID) -} - // Init starts the background goroutine that polls notification counts // and pushes updates to connected WebSocket clients. func Init() error { @@ -57,6 +52,11 @@ func run(ctx context.Context) { case <-ctx.Done(): return case <-timer.C: + if !pubsub.DefaultBroker.HasSubscribers() { + then = nowTS().Add(-2) + continue + } + now := nowTS().Add(-2) uidCounts, err := activities_model.GetUIDsAndNotificationCounts(ctx, then, now) @@ -73,7 +73,7 @@ func run(ctx context.Context) { if err != nil { continue } - pubsub.DefaultBroker.Publish(userTopic(uidCount.UserID), msg) + pubsub.DefaultBroker.Publish(pubsub.UserTopic(uidCount.UserID), msg) } then = now diff --git a/web_src/js/features/websocket.sharedworker.ts b/web_src/js/features/websocket.sharedworker.ts index 88d4870f01..491c7f2a07 100644 --- a/web_src/js/features/websocket.sharedworker.ts +++ b/web_src/js/features/websocket.sharedworker.ts @@ -52,11 +52,12 @@ class WsSource { scheduleReconnect() { if (this.clients.length === 0 || this.reconnectTimer !== null) return; - this.reconnectDelay = Math.min(this.reconnectDelay * 2, RECONNECT_DELAY_MAX); + const delay = this.reconnectDelay; this.reconnectTimer = setTimeout(() => { this.reconnectTimer = null; this.connect(); - }, this.reconnectDelay); + }, delay); + this.reconnectDelay = Math.min(this.reconnectDelay * 2, RECONNECT_DELAY_MAX); } register(port: MessagePort) { @@ -87,8 +88,8 @@ class WsSource { } } -const sourcesByUrl = new Map(); -const sourcesByPort = new Map(); +const sourcesByUrl = new Map(); +const sourcesByPort = new Map(); (self as unknown as SharedWorkerGlobalScope).addEventListener('connect', (e: MessageEvent) => { for (const port of e.ports) { @@ -106,7 +107,7 @@ const sourcesByPort = new Map(); const count = source.deregister(port); if (count === 0) { source.close(); - sourcesByUrl.set(source.url, null); + sourcesByUrl.delete(source.url); } } source = new WsSource(url); @@ -119,8 +120,8 @@ const sourcesByPort = new Map(); const count = source.deregister(port); if (count === 0) { source.close(); - sourcesByUrl.set(source.url, null); - sourcesByPort.set(port, null); + sourcesByUrl.delete(source.url); + sourcesByPort.delete(port); } } else if (event.data.type === 'status') { const source = sourcesByPort.get(port);