diff --git a/routers/init.go b/routers/init.go index 2ed7a57e5c..f6775dd8fe 100644 --- a/routers/init.go +++ b/routers/init.go @@ -54,6 +54,7 @@ import ( "code.gitea.io/gitea/services/task" "code.gitea.io/gitea/services/uinotification" "code.gitea.io/gitea/services/webhook" + websocket_service "code.gitea.io/gitea/services/websocket" ) func mustInit(fn func() error) { @@ -160,6 +161,7 @@ func InitWebInstalled(ctx context.Context) { mustInit(task.Init) mustInit(repo_migrations.Init) eventsource.GetManager().Init() + mustInit(websocket_service.Init) mustInitCtx(ctx, mailer_incoming.Init) mustInitCtx(ctx, syncAppConfForGit) diff --git a/routers/web/web.go b/routers/web/web.go index a76a68ed80..8aa26c1f36 100644 --- a/routers/web/web.go +++ b/routers/web/web.go @@ -41,6 +41,7 @@ import ( "code.gitea.io/gitea/routers/web/user" user_setting "code.gitea.io/gitea/routers/web/user/setting" "code.gitea.io/gitea/routers/web/user/setting/security" + gitea_websocket "code.gitea.io/gitea/routers/web/websocket" auth_service "code.gitea.io/gitea/services/auth" "code.gitea.io/gitea/services/context" "code.gitea.io/gitea/services/forms" @@ -588,6 +589,7 @@ func registerWebRoutes(m *web.Router, webAuth *AuthMiddleware) { }, reqSignOut) m.Any("/user/events", routing.MarkLongPolling, events.Events) + m.Get("/-/ws", gitea_websocket.Serve) m.Group("/login/oauth", func() { m.Group("", func() { diff --git a/routers/web/websocket/websocket.go b/routers/web/websocket/websocket.go new file mode 100644 index 0000000000..e0fc955cfc --- /dev/null +++ b/routers/web/websocket/websocket.go @@ -0,0 +1,53 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package websocket + +import ( + "encoding/json" + "fmt" + + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/services/context" + "code.gitea.io/gitea/services/pubsub" + + gitea_ws "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" +) + +// Serve handles WebSocket upgrade and event delivery for the signed-in user. +func Serve(ctx *context.Context) { + if !ctx.IsSigned { + ctx.Status(401) + return + } + + conn, err := gitea_ws.Accept(ctx.Resp, ctx.Req, &gitea_ws.AcceptOptions{ + InsecureSkipVerify: false, + }) + if err != nil { + log.Error("websocket: accept failed: %v", err) + return + } + defer conn.CloseNow() //nolint:errcheck + + topic := fmt.Sprintf("user-%d", ctx.Doer.ID) + ch, cancel := pubsub.DefaultBroker.Subscribe(topic) + defer cancel() + + wsCtx := ctx.Req.Context() + for { + select { + case <-wsCtx.Done(): + return + case msg, ok := <-ch: + if !ok { + return + } + if err := wsjson.Write(wsCtx, conn, json.RawMessage(msg)); err != nil { + log.Trace("websocket: write failed: %v", err) + return + } + } + } +} diff --git a/services/context/response.go b/services/context/response.go index c7368ebc6f..ac86820d70 100644 --- a/services/context/response.go +++ b/services/context/response.go @@ -4,6 +4,8 @@ package context import ( + "bufio" + "net" "net/http" web_types "code.gitea.io/gitea/modules/web/types" @@ -67,6 +69,15 @@ func (r *Response) WriteHeader(statusCode int) { } } +// Hijack implements http.Hijacker by forwarding to the underlying ResponseWriter. +// This is required for WebSocket upgrades. +func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if h, ok := r.ResponseWriter.(http.Hijacker); ok { + return h.Hijack() + } + return nil, nil, http.ErrNotSupported +} + // Flush flushes cached data func (r *Response) Flush() { if f, ok := r.ResponseWriter.(http.Flusher); ok { diff --git a/services/pubsub/broker.go b/services/pubsub/broker.go new file mode 100644 index 0000000000..c2f2dc1026 --- /dev/null +++ b/services/pubsub/broker.go @@ -0,0 +1,65 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package pubsub + +import ( + "sync" +) + +// Broker is a simple in-memory pub/sub broker. +// It supports fan-out: one Publish call delivers the message to all active subscribers. +type Broker struct { + mu sync.RWMutex + subs map[string][]chan []byte +} + +// DefaultBroker is the global singleton used by both routers and notifiers. +var DefaultBroker = NewBroker() + +// NewBroker creates a new in-memory Broker. +func NewBroker() *Broker { + return &Broker{ + subs: make(map[string][]chan []byte), + } +} + +// Subscribe returns a channel that receives messages published to topic. +// Call the returned cancel function to unsubscribe. +func (b *Broker) Subscribe(topic string) (<-chan []byte, func()) { + ch := make(chan []byte, 8) + + b.mu.Lock() + b.subs[topic] = append(b.subs[topic], ch) + b.mu.Unlock() + + cancel := func() { + b.mu.Lock() + defer b.mu.Unlock() + subs := b.subs[topic] + for i, sub := range subs { + if sub == ch { + b.subs[topic] = append(subs[:i], subs[i+1:]...) + break + } + } + close(ch) + } + return ch, cancel +} + +// Publish sends msg to all subscribers of topic. +// Non-blocking: slow subscribers are skipped. +func (b *Broker) Publish(topic string, msg []byte) { + b.mu.RLock() + subs := b.subs[topic] + b.mu.RUnlock() + + for _, ch := range subs { + select { + case ch <- msg: + default: + // subscriber too slow — skip + } + } +} diff --git a/services/websocket/notifier.go b/services/websocket/notifier.go new file mode 100644 index 0000000000..85c98bfb7b --- /dev/null +++ b/services/websocket/notifier.go @@ -0,0 +1,76 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package websocket + +import ( + "context" + "fmt" + "time" + + activities_model "code.gitea.io/gitea/models/activities" + "code.gitea.io/gitea/modules/graceful" + "code.gitea.io/gitea/modules/json" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/process" + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/timeutil" + "code.gitea.io/gitea/services/pubsub" +) + +type notificationCountEvent struct { + Type string `json:"type"` + Count int64 `json:"count"` +} + +func userTopic(userID int64) string { + return fmt.Sprintf("user-%d", userID) +} + +// Init starts the background goroutine that polls notification counts +// and pushes updates to connected WebSocket clients. +func Init() error { + go graceful.GetManager().RunWithShutdownContext(run) + return nil +} + +func run(ctx context.Context) { + ctx, _, finished := process.GetManager().AddTypedContext(ctx, "Service: WebSocket", process.SystemProcessType, true) + defer finished() + + if setting.UI.Notification.EventSourceUpdateTime <= 0 { + return + } + + then := timeutil.TimeStampNow().Add(-2) + timer := time.NewTicker(setting.UI.Notification.EventSourceUpdateTime) + defer timer.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-timer.C: + now := timeutil.TimeStampNow().Add(-2) + + uidCounts, err := activities_model.GetUIDsAndNotificationCounts(ctx, then, now) + if err != nil { + log.Error("websocket: GetUIDsAndNotificationCounts: %v", err) + continue + } + + for _, uidCount := range uidCounts { + msg, err := json.Marshal(notificationCountEvent{ + Type: "notification-count", + Count: uidCount.Count, + }) + if err != nil { + continue + } + pubsub.DefaultBroker.Publish(userTopic(uidCount.UserID), msg) + } + + then = now + } + } +} diff --git a/web_src/js/features/notification.ts b/web_src/js/features/notification.ts index 915f65f88d..e31fd5231e 100644 --- a/web_src/js/features/notification.ts +++ b/web_src/js/features/notification.ts @@ -5,12 +5,12 @@ import {logoutFromWorker} from '../modules/worker.ts'; const {appSubUrl, notificationSettings, assetVersionEncoded} = window.config; let notificationSequenceNumber = 0; -async function receiveUpdateCount(event: MessageEvent<{type: string, data: string}>) { +async function receiveUpdateCount(event: MessageEvent<{type: string, count: number}>) { try { - const data = JSON.parse(event.data.data); - for (const count of document.querySelectorAll('.notification_count')) { - count.classList.toggle('tw-hidden', data.Count === 0); - count.textContent = `${data.Count}`; + const {count} = event.data; + for (const el of document.querySelectorAll('.notification_count')) { + el.classList.toggle('tw-hidden', count === 0); + el.textContent = `${count}`; } await updateNotificationTable(); } catch (error) { @@ -21,55 +21,38 @@ async function receiveUpdateCount(event: MessageEvent<{type: string, data: strin export function initNotificationCount() { if (!document.querySelector('.notification_count')) return; - let usingPeriodicPoller = false; const startPeriodicPoller = (timeout: number, lastCount?: number) => { if (timeout <= 0 || !Number.isFinite(timeout)) return; - usingPeriodicPoller = true; lastCount = lastCount ?? getCurrentCount(); setTimeout(async () => { await updateNotificationCountWithCallback(startPeriodicPoller, timeout, lastCount); }, timeout); }; - if (notificationSettings.EventSourceUpdateTime > 0 && window.EventSource && window.SharedWorker) { - // Try to connect to the event source via the shared worker first - const worker = new SharedWorker(`${window.__webpack_public_path__}js/eventsource.sharedworker.js?v=${assetVersionEncoded}`, 'notification-worker'); + if (notificationSettings.EventSourceUpdateTime > 0 && window.SharedWorker) { + // Connect via WebSocket SharedWorker (one connection shared across all tabs) + const wsUrl = `${window.location.origin}${appSubUrl}/-/ws`.replace(/^http/, 'ws'); + const worker = new SharedWorker(`${window.__webpack_public_path__}js/websocket.sharedworker.js?v=${assetVersionEncoded}`, 'notification-worker'); worker.addEventListener('error', (event) => { console.error('worker error', event); }); worker.port.addEventListener('messageerror', () => { console.error('unable to deserialize message'); }); - worker.port.postMessage({ - type: 'start', - url: `${window.location.origin}${appSubUrl}/user/events`, - }); - worker.port.addEventListener('message', (event: MessageEvent<{type: string, data: string}>) => { + worker.port.postMessage({type: 'start', url: wsUrl}); + worker.port.addEventListener('message', (event: MessageEvent<{type: string, count: number, message?: string}>) => { if (!event.data || !event.data.type) { console.error('unknown worker message event', event); return; } if (event.data.type === 'notification-count') { receiveUpdateCount(event); // no await - } else if (event.data.type === 'no-event-source') { - // browser doesn't support EventSource, falling back to periodic poller - if (!usingPeriodicPoller) startPeriodicPoller(notificationSettings.MinTimeout); } else if (event.data.type === 'error') { console.error('worker port event error', event.data); } else if (event.data.type === 'logout') { - if (event.data.data !== 'here') { - return; - } - worker.port.postMessage({ - type: 'close', - }); + worker.port.postMessage({type: 'close'}); worker.port.close(); logoutFromWorker(); - } else if (event.data.type === 'close') { - worker.port.postMessage({ - type: 'close', - }); - worker.port.close(); } }); worker.port.addEventListener('error', (e) => { @@ -77,9 +60,7 @@ export function initNotificationCount() { }); worker.port.start(); window.addEventListener('beforeunload', () => { - worker.port.postMessage({ - type: 'close', - }); + worker.port.postMessage({type: 'close'}); worker.port.close(); }); diff --git a/web_src/js/features/websocket.sharedworker.ts b/web_src/js/features/websocket.sharedworker.ts new file mode 100644 index 0000000000..88d4870f01 --- /dev/null +++ b/web_src/js/features/websocket.sharedworker.ts @@ -0,0 +1,144 @@ +// One WebSocket connection per URL, shared across all tabs via SharedWorker. +// Messages from the server are JSON objects broadcast to all connected ports. +export {}; // make this a module to avoid global scope conflicts with other sharedworker files + +const RECONNECT_DELAY_INITIAL = 50; +const RECONNECT_DELAY_MAX = 10000; + +class WsSource { + url: string; + ws: WebSocket | null; + clients: MessagePort[]; + reconnectTimer: ReturnType | null; + reconnectDelay: number; + + constructor(url: string) { + this.url = url; + this.ws = null; + this.clients = []; + this.reconnectTimer = null; + this.reconnectDelay = RECONNECT_DELAY_INITIAL; + this.connect(); + } + + connect() { + this.ws = new WebSocket(this.url); + + this.ws.addEventListener('open', () => { + this.reconnectDelay = RECONNECT_DELAY_INITIAL; + this.broadcast({type: 'status', message: `connected to ${this.url}`}); + }); + + this.ws.addEventListener('message', (event: MessageEvent) => { + try { + const msg = JSON.parse(event.data); + this.broadcast(msg); + } catch { + // ignore malformed JSON + } + }); + + this.ws.addEventListener('close', () => { + this.ws = null; + this.scheduleReconnect(); + }); + + this.ws.addEventListener('error', () => { + this.broadcast({type: 'error', message: 'websocket error'}); + this.ws = null; + this.scheduleReconnect(); + }); + } + + scheduleReconnect() { + if (this.clients.length === 0 || this.reconnectTimer !== null) return; + this.reconnectDelay = Math.min(this.reconnectDelay * 2, RECONNECT_DELAY_MAX); + this.reconnectTimer = setTimeout(() => { + this.reconnectTimer = null; + this.connect(); + }, this.reconnectDelay); + } + + register(port: MessagePort) { + if (this.clients.includes(port)) return; + this.clients.push(port); + port.postMessage({type: 'status', message: `registered to ${this.url}`}); + } + + deregister(port: MessagePort): number { + const idx = this.clients.indexOf(port); + if (idx >= 0) this.clients.splice(idx, 1); + return this.clients.length; + } + + close() { + if (this.reconnectTimer !== null) { + clearTimeout(this.reconnectTimer); + this.reconnectTimer = null; + } + this.ws?.close(); + this.ws = null; + } + + broadcast(msg: unknown) { + for (const port of this.clients) { + port.postMessage(msg); + } + } +} + +const sourcesByUrl = new Map(); +const sourcesByPort = new Map(); + +(self as unknown as SharedWorkerGlobalScope).addEventListener('connect', (e: MessageEvent) => { + for (const port of e.ports) { + port.addEventListener('message', (event: MessageEvent) => { + if (event.data.type === 'start') { + const {url} = event.data; + let source = sourcesByUrl.get(url); + if (source) { + source.register(port); + sourcesByPort.set(port, source); + return; + } + source = sourcesByPort.get(port); + if (source) { + const count = source.deregister(port); + if (count === 0) { + source.close(); + sourcesByUrl.set(source.url, null); + } + } + source = new WsSource(url); + source.register(port); + sourcesByUrl.set(url, source); + sourcesByPort.set(port, source); + } else if (event.data.type === 'close') { + const source = sourcesByPort.get(port); + if (!source) return; + const count = source.deregister(port); + if (count === 0) { + source.close(); + sourcesByUrl.set(source.url, null); + sourcesByPort.set(port, null); + } + } else if (event.data.type === 'status') { + const source = sourcesByPort.get(port); + if (!source) { + port.postMessage({type: 'status', message: 'not connected'}); + return; + } + port.postMessage({ + type: 'status', + message: `url: ${source.url} readyState: ${source.ws?.readyState ?? 'null'}`, + }); + } else { + port.postMessage({ + type: 'error', + message: `received but don't know how to handle: ${JSON.stringify(event.data)}`, + }); + } + }); + port.start(); + } +}); diff --git a/webpack.config.ts b/webpack.config.ts index e3ef996909..cd601d6653 100644 --- a/webpack.config.ts +++ b/webpack.config.ts @@ -79,6 +79,9 @@ export default { 'eventsource.sharedworker': [ fileURLToPath(new URL('web_src/js/features/eventsource.sharedworker.ts', import.meta.url)), ], + 'websocket.sharedworker': [ + fileURLToPath(new URL('web_src/js/features/websocket.sharedworker.ts', import.meta.url)), + ], ...(!isProduction && { devtest: [ fileURLToPath(new URL('web_src/js/standalone/devtest.ts', import.meta.url)),