diff --git a/routers/init.go b/routers/init.go index 2ed7a57e5c..f6775dd8fe 100644 --- a/routers/init.go +++ b/routers/init.go @@ -54,6 +54,7 @@ import ( "code.gitea.io/gitea/services/task" "code.gitea.io/gitea/services/uinotification" "code.gitea.io/gitea/services/webhook" + websocket_service "code.gitea.io/gitea/services/websocket" ) func mustInit(fn func() error) { @@ -160,6 +161,7 @@ func InitWebInstalled(ctx context.Context) { mustInit(task.Init) mustInit(repo_migrations.Init) eventsource.GetManager().Init() + mustInit(websocket_service.Init) mustInitCtx(ctx, mailer_incoming.Init) mustInitCtx(ctx, syncAppConfForGit) diff --git a/routers/web/web.go b/routers/web/web.go index 72d2c27eaf..6cf209a886 100644 --- a/routers/web/web.go +++ b/routers/web/web.go @@ -41,6 +41,7 @@ import ( "code.gitea.io/gitea/routers/web/user" user_setting "code.gitea.io/gitea/routers/web/user/setting" "code.gitea.io/gitea/routers/web/user/setting/security" + gitea_websocket "code.gitea.io/gitea/routers/web/websocket" auth_service "code.gitea.io/gitea/services/auth" "code.gitea.io/gitea/services/context" "code.gitea.io/gitea/services/forms" @@ -592,6 +593,7 @@ func registerWebRoutes(m *web.Router, webAuth *AuthMiddleware) { }, reqSignOut) m.Any("/user/events", routing.MarkLongPolling, events.Events) + m.Get("/-/ws", gitea_websocket.Serve) m.Group("/login/oauth", func() { m.Group("", func() { diff --git a/routers/web/websocket/websocket.go b/routers/web/websocket/websocket.go new file mode 100644 index 0000000000..e0fc955cfc --- /dev/null +++ b/routers/web/websocket/websocket.go @@ -0,0 +1,53 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package websocket + +import ( + "encoding/json" + "fmt" + + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/services/context" + "code.gitea.io/gitea/services/pubsub" + + gitea_ws "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" +) + +// Serve handles WebSocket upgrade and event delivery for the signed-in user. +func Serve(ctx *context.Context) { + if !ctx.IsSigned { + ctx.Status(401) + return + } + + conn, err := gitea_ws.Accept(ctx.Resp, ctx.Req, &gitea_ws.AcceptOptions{ + InsecureSkipVerify: false, + }) + if err != nil { + log.Error("websocket: accept failed: %v", err) + return + } + defer conn.CloseNow() //nolint:errcheck + + topic := fmt.Sprintf("user-%d", ctx.Doer.ID) + ch, cancel := pubsub.DefaultBroker.Subscribe(topic) + defer cancel() + + wsCtx := ctx.Req.Context() + for { + select { + case <-wsCtx.Done(): + return + case msg, ok := <-ch: + if !ok { + return + } + if err := wsjson.Write(wsCtx, conn, json.RawMessage(msg)); err != nil { + log.Trace("websocket: write failed: %v", err) + return + } + } + } +} diff --git a/services/pubsub/broker.go b/services/pubsub/broker.go new file mode 100644 index 0000000000..c2f2dc1026 --- /dev/null +++ b/services/pubsub/broker.go @@ -0,0 +1,65 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package pubsub + +import ( + "sync" +) + +// Broker is a simple in-memory pub/sub broker. +// It supports fan-out: one Publish call delivers the message to all active subscribers. +type Broker struct { + mu sync.RWMutex + subs map[string][]chan []byte +} + +// DefaultBroker is the global singleton used by both routers and notifiers. +var DefaultBroker = NewBroker() + +// NewBroker creates a new in-memory Broker. +func NewBroker() *Broker { + return &Broker{ + subs: make(map[string][]chan []byte), + } +} + +// Subscribe returns a channel that receives messages published to topic. +// Call the returned cancel function to unsubscribe. +func (b *Broker) Subscribe(topic string) (<-chan []byte, func()) { + ch := make(chan []byte, 8) + + b.mu.Lock() + b.subs[topic] = append(b.subs[topic], ch) + b.mu.Unlock() + + cancel := func() { + b.mu.Lock() + defer b.mu.Unlock() + subs := b.subs[topic] + for i, sub := range subs { + if sub == ch { + b.subs[topic] = append(subs[:i], subs[i+1:]...) + break + } + } + close(ch) + } + return ch, cancel +} + +// Publish sends msg to all subscribers of topic. +// Non-blocking: slow subscribers are skipped. +func (b *Broker) Publish(topic string, msg []byte) { + b.mu.RLock() + subs := b.subs[topic] + b.mu.RUnlock() + + for _, ch := range subs { + select { + case ch <- msg: + default: + // subscriber too slow — skip + } + } +} diff --git a/services/websocket/notifier.go b/services/websocket/notifier.go new file mode 100644 index 0000000000..85c98bfb7b --- /dev/null +++ b/services/websocket/notifier.go @@ -0,0 +1,76 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package websocket + +import ( + "context" + "fmt" + "time" + + activities_model "code.gitea.io/gitea/models/activities" + "code.gitea.io/gitea/modules/graceful" + "code.gitea.io/gitea/modules/json" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/process" + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/timeutil" + "code.gitea.io/gitea/services/pubsub" +) + +type notificationCountEvent struct { + Type string `json:"type"` + Count int64 `json:"count"` +} + +func userTopic(userID int64) string { + return fmt.Sprintf("user-%d", userID) +} + +// Init starts the background goroutine that polls notification counts +// and pushes updates to connected WebSocket clients. +func Init() error { + go graceful.GetManager().RunWithShutdownContext(run) + return nil +} + +func run(ctx context.Context) { + ctx, _, finished := process.GetManager().AddTypedContext(ctx, "Service: WebSocket", process.SystemProcessType, true) + defer finished() + + if setting.UI.Notification.EventSourceUpdateTime <= 0 { + return + } + + then := timeutil.TimeStampNow().Add(-2) + timer := time.NewTicker(setting.UI.Notification.EventSourceUpdateTime) + defer timer.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-timer.C: + now := timeutil.TimeStampNow().Add(-2) + + uidCounts, err := activities_model.GetUIDsAndNotificationCounts(ctx, then, now) + if err != nil { + log.Error("websocket: GetUIDsAndNotificationCounts: %v", err) + continue + } + + for _, uidCount := range uidCounts { + msg, err := json.Marshal(notificationCountEvent{ + Type: "notification-count", + Count: uidCount.Count, + }) + if err != nil { + continue + } + pubsub.DefaultBroker.Publish(userTopic(uidCount.UserID), msg) + } + + then = now + } + } +} diff --git a/web_src/js/features/websocket.sharedworker.ts b/web_src/js/features/websocket.sharedworker.ts new file mode 100644 index 0000000000..88d4870f01 --- /dev/null +++ b/web_src/js/features/websocket.sharedworker.ts @@ -0,0 +1,144 @@ +// One WebSocket connection per URL, shared across all tabs via SharedWorker. +// Messages from the server are JSON objects broadcast to all connected ports. +export {}; // make this a module to avoid global scope conflicts with other sharedworker files + +const RECONNECT_DELAY_INITIAL = 50; +const RECONNECT_DELAY_MAX = 10000; + +class WsSource { + url: string; + ws: WebSocket | null; + clients: MessagePort[]; + reconnectTimer: ReturnType | null; + reconnectDelay: number; + + constructor(url: string) { + this.url = url; + this.ws = null; + this.clients = []; + this.reconnectTimer = null; + this.reconnectDelay = RECONNECT_DELAY_INITIAL; + this.connect(); + } + + connect() { + this.ws = new WebSocket(this.url); + + this.ws.addEventListener('open', () => { + this.reconnectDelay = RECONNECT_DELAY_INITIAL; + this.broadcast({type: 'status', message: `connected to ${this.url}`}); + }); + + this.ws.addEventListener('message', (event: MessageEvent) => { + try { + const msg = JSON.parse(event.data); + this.broadcast(msg); + } catch { + // ignore malformed JSON + } + }); + + this.ws.addEventListener('close', () => { + this.ws = null; + this.scheduleReconnect(); + }); + + this.ws.addEventListener('error', () => { + this.broadcast({type: 'error', message: 'websocket error'}); + this.ws = null; + this.scheduleReconnect(); + }); + } + + scheduleReconnect() { + if (this.clients.length === 0 || this.reconnectTimer !== null) return; + this.reconnectDelay = Math.min(this.reconnectDelay * 2, RECONNECT_DELAY_MAX); + this.reconnectTimer = setTimeout(() => { + this.reconnectTimer = null; + this.connect(); + }, this.reconnectDelay); + } + + register(port: MessagePort) { + if (this.clients.includes(port)) return; + this.clients.push(port); + port.postMessage({type: 'status', message: `registered to ${this.url}`}); + } + + deregister(port: MessagePort): number { + const idx = this.clients.indexOf(port); + if (idx >= 0) this.clients.splice(idx, 1); + return this.clients.length; + } + + close() { + if (this.reconnectTimer !== null) { + clearTimeout(this.reconnectTimer); + this.reconnectTimer = null; + } + this.ws?.close(); + this.ws = null; + } + + broadcast(msg: unknown) { + for (const port of this.clients) { + port.postMessage(msg); + } + } +} + +const sourcesByUrl = new Map(); +const sourcesByPort = new Map(); + +(self as unknown as SharedWorkerGlobalScope).addEventListener('connect', (e: MessageEvent) => { + for (const port of e.ports) { + port.addEventListener('message', (event: MessageEvent) => { + if (event.data.type === 'start') { + const {url} = event.data; + let source = sourcesByUrl.get(url); + if (source) { + source.register(port); + sourcesByPort.set(port, source); + return; + } + source = sourcesByPort.get(port); + if (source) { + const count = source.deregister(port); + if (count === 0) { + source.close(); + sourcesByUrl.set(source.url, null); + } + } + source = new WsSource(url); + source.register(port); + sourcesByUrl.set(url, source); + sourcesByPort.set(port, source); + } else if (event.data.type === 'close') { + const source = sourcesByPort.get(port); + if (!source) return; + const count = source.deregister(port); + if (count === 0) { + source.close(); + sourcesByUrl.set(source.url, null); + sourcesByPort.set(port, null); + } + } else if (event.data.type === 'status') { + const source = sourcesByPort.get(port); + if (!source) { + port.postMessage({type: 'status', message: 'not connected'}); + return; + } + port.postMessage({ + type: 'status', + message: `url: ${source.url} readyState: ${source.ws?.readyState ?? 'null'}`, + }); + } else { + port.postMessage({ + type: 'error', + message: `received but don't know how to handle: ${JSON.stringify(event.data)}`, + }); + } + }); + port.start(); + } +});