diff --git a/modules/google/groups.go b/modules/google/groups.go new file mode 100644 index 0000000000..b8a3af159d --- /dev/null +++ b/modules/google/groups.go @@ -0,0 +1,138 @@ +// Copyright 2026 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package google + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + + "code.gitea.io/gitea/modules/json" + "code.gitea.io/gitea/modules/log" + + "github.com/markbates/goth" +) + +const ( + IAMScope = "https://www.googleapis.com/auth/cloud-identity.groups.readonly" + defaultIAMGroupsEndpoint = "https://content-cloudidentity.googleapis.com/v1/groups/-/memberships:searchDirectGroups" +) + +// maxGroupPages is the maximum number of pages fetched from the Google +// Cloud Identity API. The API returns up to 200 groups per page by default, +// so this caps group membership at 4,000 groups per user — far beyond any +// realistic Google Workspace organization. +const maxGroupPages = 20 + +// Client calls Google Workspace APIs. +type Client struct { + httpClient *http.Client + groupsEndpoint string + claimName string + failLoginOnAdditionalInfoError bool +} + +// NewClient creates a Client using the given authenticated HTTP client. +// The client should be built from an OAuth2 token carrying IAMScope. +func NewClient(httpClient *http.Client, claimName string, failLoginOnAdditionalInfoError bool) *Client { + return &Client{ + httpClient: httpClient, + groupsEndpoint: defaultIAMGroupsEndpoint, + claimName: claimName, + failLoginOnAdditionalInfoError: failLoginOnAdditionalInfoError, + } +} + +// groupMembership represents a single membership entry returned by the +// Cloud Identity Groups API searchDirectGroups endpoint. +type groupMembership struct { + GroupKey struct { + ID string `json:"id"` + } `json:"groupKey"` +} + +// groupsResponse is the paged response from the Cloud Identity API. +type groupsResponse struct { + Memberships []groupMembership `json:"memberships"` + NextPageToken string `json:"nextPageToken"` +} + +// FetchGroups queries the Google Cloud Identity Groups API for all +// groups the given user (identified by email) is a direct member of. +// The caller must supply an HTTP client already authenticated with an access +// token that carries the IAMScope scope. +func (c *Client) FetchGroups(ctx context.Context, email string) ([]string, error) { + groups := make([]string, 0, 16) + pageToken := "" + + for range make([]struct{}, maxGroupPages) { + params := url.Values{} + params.Set("query", fmt.Sprintf("member_key_id=='%s'", email)) + if pageToken != "" { + params.Set("pageToken", pageToken) + } + apiURL := c.groupsEndpoint + "?" + params.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil) + if err != nil { + return nil, fmt.Errorf("google groups: build request: %w", err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("google groups: HTTP request: %w", err) + } + body, readErr := io.ReadAll(resp.Body) + resp.Body.Close() + if readErr != nil { + return nil, fmt.Errorf("google groups: read response: %w", readErr) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("google groups: API returned %d: %s", resp.StatusCode, body) + } + + var page groupsResponse + if err := json.Unmarshal(body, &page); err != nil { + return nil, fmt.Errorf("google groups: decode response: %w", err) + } + + for _, m := range page.Memberships { + if m.GroupKey.ID != "" { + groups = append(groups, m.GroupKey.ID) + } + } + + if page.NextPageToken == "" { + break + } + pageToken = page.NextPageToken + } + + return groups, nil +} + +// FetchAdditionalInfo implements oauth2.AdditionalInfoProvider. +// It fetches Google Workspace group memberships and injects them into +// gothUser.RawData under the given claimName key. +func (c *Client) FetchAdditionalInfo(ctx context.Context, user goth.User) (goth.User, error) { + groups, err := c.FetchGroups(ctx, user.Email) + if err != nil { + return user, err + } + if user.RawData == nil { + user.RawData = make(map[string]any) + } + if existing, has := user.RawData[c.claimName]; has { + log.Warn("OAuth2 Google: RawData already contains claim %q with some value. Consider to use different claim name for groups information", c.claimName, existing) + } + user.RawData[c.claimName] = groups + return user, nil +} + +// FailLoginOnAdditionalInfoError implements oauth2.AdditionalInfoProvider. +func (c *Client) FailLoginOnAdditionalInfoError() bool { + return c.failLoginOnAdditionalInfoError +} diff --git a/modules/google/groups_test.go b/modules/google/groups_test.go new file mode 100644 index 0000000000..d5bc83567d --- /dev/null +++ b/modules/google/groups_test.go @@ -0,0 +1,174 @@ +// Copyright 2026 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package google + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/markbates/goth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestClient(t *testing.T, server *httptest.Server) *Client { + t.Helper() + c := NewClient(server.Client(), "groups", false) + c.groupsEndpoint = server.URL + return c +} + +func mockGroupsServer(t *testing.T, expectedEmail string, pages [][]string) *httptest.Server { + t.Helper() + pageIndex := 0 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Validate the query contains the correct member_key_id + expectedQuery := fmt.Sprintf("member_key_id=='%s'", expectedEmail) + assert.Equal(t, expectedQuery, r.URL.Query().Get("query")) + + var memberships []string + for _, g := range pages[pageIndex] { + memberships = append(memberships, fmt.Sprintf(`{"groupKey":{"id":%q}}`, g)) + } + + nextPageToken := "" + if pageIndex < len(pages)-1 { + nextPageToken = "page-token" + } + pageIndex++ + + body := fmt.Sprintf( + `{"memberships":[%s],"nextPageToken":%q}`, + strings.Join(memberships, ","), + nextPageToken, + ) + w.Header().Set("Content-Type", "application/json") + _, _ = fmt.Fprint(w, body) + })) +} + +func TestFetchGoogleGroups_SinglePage(t *testing.T) { + server := mockGroupsServer(t, "user@example.com", [][]string{ + {"group-a@example.com", "group-b@example.com"}, + }) + defer server.Close() + + client := newTestClient(t, server) + groups, err := client.FetchGroups(context.Background(), "user@example.com") + require.NoError(t, err) + assert.ElementsMatch(t, []string{"group-a@example.com", "group-b@example.com"}, groups) +} + +func TestFetchGroups_SinglePage(t *testing.T) { + server := mockGroupsServer(t, "user@example.com", [][]string{ + {"group-a@example.com", "group-b@example.com"}, + }) + defer server.Close() + + client := newTestClient(t, server) + groups, err := client.FetchGroups(context.Background(), "user@example.com") + require.NoError(t, err) + assert.ElementsMatch(t, []string{"group-a@example.com", "group-b@example.com"}, groups) +} + +func TestFetchGoogleGroups_MultiPage(t *testing.T) { + server := mockGroupsServer(t, "user@example.com", [][]string{ + {"group-a@example.com"}, + {"group-b@example.com", "group-c@example.com"}, + }) + defer server.Close() + + client := newTestClient(t, server) + groups, err := client.FetchGroups(context.Background(), "user@example.com") + require.NoError(t, err) + assert.ElementsMatch(t, []string{"group-a@example.com", "group-b@example.com", "group-c@example.com"}, groups) +} + +func TestFetchGoogleGroups_Empty(t *testing.T) { + server := mockGroupsServer(t, "user@example.com", [][]string{{}}) + defer server.Close() + + client := newTestClient(t, server) + groups, err := client.FetchGroups(context.Background(), "user@example.com") + require.NoError(t, err) + assert.Empty(t, groups) +} + +func TestFetchGoogleGroups_APIError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = fmt.Fprint(w, `{"error":"forbidden"}`) + })) + defer server.Close() + + client := newTestClient(t, server) + groups, err := client.FetchGroups(context.Background(), "user@example.com") + require.Error(t, err) + assert.Nil(t, groups) + assert.Contains(t, err.Error(), "403") +} + +func TestFetchGoogleGroups_InvalidJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprint(w, `not valid json`) + })) + defer server.Close() + + client := newTestClient(t, server) + groups, err := client.FetchGroups(context.Background(), "user@example.com") + require.Error(t, err) + assert.Nil(t, groups) +} + +func TestFetchAdditionalInfo_InjectsClaimBeforeValidation(t *testing.T) { + server := mockGroupsServer(t, "user@example.com", [][]string{ + {"required-group@example.com"}, + }) + defer server.Close() + + c := newTestClient(t, server) + c.claimName = "groups" + + user := goth.User{ + Email: "user@example.com", + RawData: map[string]any{}, + } + + enriched, err := c.FetchAdditionalInfo(context.Background(), user) + require.NoError(t, err) + + // Verify the claim is present and contains the group — simulating what + // RequiredClaimName validation would check after enrichment runs. + groups, ok := enriched.RawData["groups"] + require.True(t, ok, "groups claim must be present in RawData after enrichment") + groupSlice, ok := groups.([]string) + require.True(t, ok) + assert.Contains(t, groupSlice, "required-group@example.com") +} + +func TestFetchAdditionalInfo_ErrorDoesNotInjectClaim(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = fmt.Fprint(w, `{"error":"forbidden"}`) + })) + defer server.Close() + + c := newTestClient(t, server) + c.claimName = "groups" + + user := goth.User{ + Email: "user@example.com", + RawData: map[string]any{}, + } + + enriched, err := c.FetchAdditionalInfo(context.Background(), user) + require.Error(t, err) + _, hasGroups := enriched.RawData["groups"] + assert.False(t, hasGroups) +} diff --git a/options/locale/locale_en-US.json b/options/locale/locale_en-US.json index c7ec133e57..1e14f4d239 100644 --- a/options/locale/locale_en-US.json +++ b/options/locale/locale_en-US.json @@ -428,6 +428,8 @@ "auth.oauth.signin.error.general": "There was an error processing the authorization request: %s. If this error persists, please contact the site administrator.", "auth.oauth.signin.error.access_denied": "The authorization request was denied.", "auth.oauth.signin.error.temporarily_unavailable": "Authorization failed because the authentication server is temporarily unavailable. Please try again later.", + "auth.oauth.required_additional_info_fetch_failed_banner.title": "OAuth2 additional information synchronization warning", + "auth.oauth.required_additional_info_fetch_failed_banner.desc": "Gitea recently failed to retrieve required additional information from authentication source \"%s\" (last failure: %s). Users may keep their previously synchronized admin/restricted/team state until retrieval succeeds again. Check logs for details.", "auth.oauth_callback_unable_auto_reg": "Auto Registration is enabled, but OAuth2 Provider %[1]s returned missing fields: %[2]s, unable to create an account automatically. Please create or link to an account, or contact the site administrator.", "auth.openid_connect_submit": "Connect", "auth.openid_connect_title": "Connect to an existing account", diff --git a/routers/common/pagetmpl.go b/routers/common/pagetmpl.go index c48596d48b..af86fb5038 100644 --- a/routers/common/pagetmpl.go +++ b/routers/common/pagetmpl.go @@ -12,6 +12,7 @@ import ( "code.gitea.io/gitea/models/db" issues_model "code.gitea.io/gitea/models/issues" "code.gitea.io/gitea/modules/log" + oauth2_source "code.gitea.io/gitea/services/auth/source/oauth2" "code.gitea.io/gitea/services/context" ) @@ -69,8 +70,16 @@ type pageGlobalDataType struct { IsSigned bool IsSiteAdmin bool - GetNotificationUnreadCount func() int64 - GetActiveStopwatch func() *StopwatchTmplInfo + GetNotificationUnreadCount func() int64 + GetActiveStopwatch func() *StopwatchTmplInfo + GetRequiredAdditionalInfoFailureWarning func() *oauth2_source.RequiredAdditionalInfoFailureWarning +} + +func oauth2RequiredAdditionalInfoFailureWarning(ctx *context.Context) *oauth2_source.RequiredAdditionalInfoFailureWarning { + if ctx.Doer == nil || !ctx.Doer.IsAdmin { + return nil + } + return oauth2_source.GetRequiredAdditionalInfoFailureWarning(ctx.Cache) } func PageGlobalData(ctx *context.Context) { @@ -79,5 +88,8 @@ func PageGlobalData(ctx *context.Context) { data.IsSiteAdmin = ctx.Doer != nil && ctx.Doer.IsAdmin data.GetNotificationUnreadCount = sync.OnceValue(func() int64 { return notificationUnreadCount(ctx) }) data.GetActiveStopwatch = sync.OnceValue(func() *StopwatchTmplInfo { return getActiveStopwatch(ctx) }) + data.GetRequiredAdditionalInfoFailureWarning = sync.OnceValue(func() *oauth2_source.RequiredAdditionalInfoFailureWarning { + return oauth2RequiredAdditionalInfoFailureWarning(ctx) + }) ctx.Data["PageGlobalData"] = data } diff --git a/routers/common/pagetmpl_test.go b/routers/common/pagetmpl_test.go new file mode 100644 index 0000000000..bf81957233 --- /dev/null +++ b/routers/common/pagetmpl_test.go @@ -0,0 +1,41 @@ +// Copyright 2026 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package common + +import ( + "testing" + "time" + + user_model "code.gitea.io/gitea/models/user" + "code.gitea.io/gitea/modules/cache" + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/timeutil" + oauth2_source "code.gitea.io/gitea/services/auth/source/oauth2" + "code.gitea.io/gitea/services/context" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOAuth2RequiredAdditionalInfoFailureWarning_AdminOnlyVisibility(t *testing.T) { + c, err := cache.NewStringCache(setting.Cache{Adapter: "memory", Interval: 1}) + require.NoError(t, err) + + defer timeutil.MockSet(time.Unix(1_700_000_000, 0))() + oauth2_source.SetRequiredAdditionalInfoFetchFailureWarning(c, "Google Workspace") + + adminCtx := &context.Context{ + Doer: &user_model.User{IsAdmin: true}, + Cache: c, + } + nonAdminCtx := &context.Context{ + Doer: &user_model.User{IsAdmin: false}, + Cache: c, + } + + adminWarning := oauth2RequiredAdditionalInfoFailureWarning(adminCtx) + require.NotNil(t, adminWarning) + assert.Equal(t, "Google Workspace", adminWarning.SourceName) + assert.Nil(t, oauth2RequiredAdditionalInfoFailureWarning(nonAdminCtx)) +} diff --git a/routers/web/auth/oauth.go b/routers/web/auth/oauth.go index 8645aedbde..3cbffc92e8 100644 --- a/routers/web/auth/oauth.go +++ b/routers/web/auth/oauth.go @@ -232,6 +232,10 @@ func claimValueToStringSet(claimValue any) container.Set[string] { } func syncGroupsToTeams(ctx *context.Context, source *oauth2.Source, gothUser *goth.User, u *user_model.User) error { + if !shouldSyncFromGroupClaim(source, gothUser) { + return nil + } + if source.GroupTeamMap != "" || source.GroupTeamMapRemoval { groupTeamMapping, err := auth_module.UnmarshalGroupTeamMapping(source.GroupTeamMap) if err != nil { @@ -257,7 +261,22 @@ func getClaimedGroups(source *oauth2.Source, gothUser *goth.User) container.Set[ return claimValueToStringSet(groupClaims) } +func shouldSyncFromGroupClaim(source *oauth2.Source, gothUser *goth.User) bool { + // Keep historical behavior for all providers except Google Workspace: + // if the claim is missing, group-derived sync still runs on an empty set. + if source.Provider != "gplus" { + return true + } + + _, hasGroupClaim := gothUser.RawData[source.GroupClaimName] + return hasGroupClaim +} + func getUserAdminAndRestrictedFromGroupClaims(source *oauth2.Source, gothUser *goth.User) (isAdmin optional.Option[user_service.UpdateOptionField[bool]], isRestricted optional.Option[bool]) { + if !shouldSyncFromGroupClaim(source, gothUser) { + return isAdmin, isRestricted + } + groups := getClaimedGroups(source, gothUser) if source.AdminGroup != "" { @@ -461,6 +480,32 @@ func oAuth2UserLoginCallback(ctx *context.Context, authSource *auth.Source, requ return nil, goth.User{}, err } + // Enrichment must run before RequiredClaimName validation so that claims + // injected by the provider (e.g. Google Workspace groups fetched via the + // Cloud Identity API) are available when the required-claim check executes. + // Moving this block after the RequiredClaimName check would cause users to + // be incorrectly rejected when RequiredClaimName references an injected claim. + if provider := oauth2.GetAdditionalInfoProvider(oauth2Source, &gothUser); provider != nil { + enriched, err := provider.FetchAdditionalInfo(ctx, gothUser) + if err != nil { + log.Warn("OAuth2: failed to fetch additional info for %s: %v", gothUser.Email, err) + if provider.FailLoginOnAdditionalInfoError() { + sourceName := authSource.Name + if sourceName == "" { + sourceName = fmt.Sprintf("source #%d", authSource.ID) + } + oauth2.SetRequiredAdditionalInfoFetchFailureWarning(ctx.Cache, sourceName) + // Fail closed only when login directly depends on the group claim + // (for example RequiredClaimName == GroupClaimName). Other + // group-based sync features are fail-open and preserve prior state. + return nil, goth.User{}, user_model.ErrUserProhibitLogin{Name: gothUser.UserID} + } + } else { + oauth2.ClearRequiredAdditionalInfoFetchFailureWarning(ctx.Cache) + gothUser = enriched + } + } + if oauth2Source.RequiredClaimName != "" { claimInterface, has := gothUser.RawData[oauth2Source.RequiredClaimName] if !has { diff --git a/routers/web/auth/oauth_group_claims_test.go b/routers/web/auth/oauth_group_claims_test.go new file mode 100644 index 0000000000..d3e5a4f329 --- /dev/null +++ b/routers/web/auth/oauth_group_claims_test.go @@ -0,0 +1,108 @@ +// Copyright 2026 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package auth + +import ( + "testing" + + "code.gitea.io/gitea/modules/optional" + "code.gitea.io/gitea/services/auth/source/oauth2" + user_service "code.gitea.io/gitea/services/user" + + "github.com/markbates/goth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestShouldSyncFromGroupClaim(t *testing.T) { + t.Run("google claim missing", func(t *testing.T) { + source := &oauth2.Source{ + Provider: "gplus", + GroupClaimName: "groups", + } + user := &goth.User{ + RawData: map[string]any{}, + } + assert.False(t, shouldSyncFromGroupClaim(source, user)) + }) + + t.Run("google claim present and empty", func(t *testing.T) { + source := &oauth2.Source{ + Provider: "gplus", + GroupClaimName: "groups", + } + user := &goth.User{ + RawData: map[string]any{ + "groups": []string{}, + }, + } + assert.True(t, shouldSyncFromGroupClaim(source, user)) + }) + + t.Run("non google provider keeps old behavior", func(t *testing.T) { + source := &oauth2.Source{ + Provider: "openidConnect", + GroupClaimName: "groups", + } + user := &goth.User{ + RawData: map[string]any{}, + } + assert.True(t, shouldSyncFromGroupClaim(source, user)) + }) +} + +func TestGetUserAdminAndRestrictedFromGroupClaims_GoogleMissingClaim(t *testing.T) { + source := &oauth2.Source{ + Provider: "gplus", + GroupClaimName: "groups", + AdminGroup: "g-admins@example.com", + RestrictedGroup: "g-restricted@example.com", + } + user := &goth.User{ + RawData: map[string]any{}, + } + + isAdmin, isRestricted := getUserAdminAndRestrictedFromGroupClaims(source, user) + + assert.False(t, isAdmin.Has()) + assert.Equal(t, optional.None[bool](), isRestricted) +} + +func TestGetUserAdminAndRestrictedFromGroupClaims_GoogleEmptyClaim(t *testing.T) { + source := &oauth2.Source{ + Provider: "gplus", + GroupClaimName: "groups", + AdminGroup: "g-admins@example.com", + RestrictedGroup: "g-restricted@example.com", + } + user := &goth.User{ + RawData: map[string]any{ + "groups": []string{}, + }, + } + + isAdmin, isRestricted := getUserAdminAndRestrictedFromGroupClaims(source, user) + + require.True(t, isAdmin.Has()) + assert.Equal(t, user_service.UpdateOptionFieldFromSync(false), isAdmin) + assert.Equal(t, optional.Some(false), isRestricted) +} + +func TestGetUserAdminAndRestrictedFromGroupClaims_NonGoogleMissingClaim(t *testing.T) { + source := &oauth2.Source{ + Provider: "openidConnect", + GroupClaimName: "groups", + AdminGroup: "g-admins@example.com", + RestrictedGroup: "g-restricted@example.com", + } + user := &goth.User{ + RawData: map[string]any{}, + } + + isAdmin, isRestricted := getUserAdminAndRestrictedFromGroupClaims(source, user) + + require.True(t, isAdmin.Has()) + assert.Equal(t, user_service.UpdateOptionFieldFromSync(false), isAdmin) + assert.Equal(t, optional.Some(false), isRestricted) +} diff --git a/services/auth/source/oauth2/additional_info_provider.go b/services/auth/source/oauth2/additional_info_provider.go new file mode 100644 index 0000000000..fb31864aad --- /dev/null +++ b/services/auth/source/oauth2/additional_info_provider.go @@ -0,0 +1,60 @@ +// Copyright 2026 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "context" + "slices" + + google_module "code.gitea.io/gitea/modules/google" + + "github.com/markbates/goth" + go_oauth2 "golang.org/x/oauth2" +) + +// AdditionalInfoProvider is implemented by OAuth2 providers that can fetch +// additional user information (such as group memberships) that is not +// included in the standard token or userinfo response. +// The provider receives the resolved goth user, and returns a modified copy +// with any extra data injected into RawData. +type AdditionalInfoProvider interface { + FetchAdditionalInfo(ctx context.Context, user goth.User) (goth.User, error) + FailLoginOnAdditionalInfoError() bool +} + +// GetAdditionalInfoProvider returns an AdditionalInfoProvider for the given +// source if that provider supports fetching additional info, or nil if none +// applies. The returned provider is already configured with an authenticated +// HTTP client built from the access token. +func GetAdditionalInfoProvider(source *Source, gothUser *goth.User) AdditionalInfoProvider { + switch source.Provider { + case "gplus": + if slices.Contains(source.Scopes, google_module.IAMScope) { + claimName := source.GroupClaimName + if claimName == "" { + claimName = "groups" + } + oauthToken := &go_oauth2.Token{AccessToken: gothUser.AccessToken} + // Note: we use only the access token without a refresh token. + // This is intentional — the token is issued moments before this + // call during the login flow and is guaranteed to be fresh. + authenticatedClient := go_oauth2.NewClient(context.Background(), go_oauth2.StaticTokenSource(oauthToken)) + return google_module.NewClient(authenticatedClient, claimName, isGoogleGroupClaimRequiredForLoginFlow(source)) + } + } + return nil +} + +func isGoogleGroupClaimRequiredForLoginFlow(source *Source) bool { + groupClaimName := source.GroupClaimName + if groupClaimName == "" { + groupClaimName = "groups" + } + + // Fail closed only when login itself depends on the group claim. + // + // Admin/restricted/team sync can preserve the user's previous state when the + // group claim is missing, so those options intentionally stay fail-open. + return source.RequiredClaimName == groupClaimName +} diff --git a/services/auth/source/oauth2/additional_info_provider_test.go b/services/auth/source/oauth2/additional_info_provider_test.go new file mode 100644 index 0000000000..65e662ec9e --- /dev/null +++ b/services/auth/source/oauth2/additional_info_provider_test.go @@ -0,0 +1,110 @@ +// Copyright 2026 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "testing" + "time" + + "code.gitea.io/gitea/modules/cache" + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/timeutil" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIsGoogleGroupClaimRequiredForLoginFlow(t *testing.T) { + t.Run("no group-dependent options", func(t *testing.T) { + source := &Source{ + GroupClaimName: "groups", + } + assert.False(t, isGoogleGroupClaimRequiredForLoginFlow(source)) + }) + + t.Run("required claim uses group claim", func(t *testing.T) { + source := &Source{ + GroupClaimName: "custom_groups", + RequiredClaimName: "custom_groups", + } + assert.True(t, isGoogleGroupClaimRequiredForLoginFlow(source)) + }) + + t.Run("required claim uses default groups claim", func(t *testing.T) { + source := &Source{ + RequiredClaimName: "groups", + } + assert.True(t, isGoogleGroupClaimRequiredForLoginFlow(source)) + }) + + t.Run("admin group configured", func(t *testing.T) { + source := &Source{ + AdminGroup: "admins@example.com", + } + assert.False(t, isGoogleGroupClaimRequiredForLoginFlow(source)) + }) + + t.Run("restricted group configured", func(t *testing.T) { + source := &Source{ + RestrictedGroup: "restricted@example.com", + } + assert.False(t, isGoogleGroupClaimRequiredForLoginFlow(source)) + }) + + t.Run("group team mapping configured", func(t *testing.T) { + source := &Source{ + GroupTeamMap: "{\"a\": {\"org\": [\"team\"]}}", + } + assert.False(t, isGoogleGroupClaimRequiredForLoginFlow(source)) + }) + + t.Run("group team mapping removal enabled", func(t *testing.T) { + source := &Source{ + GroupTeamMapRemoval: true, + } + assert.False(t, isGoogleGroupClaimRequiredForLoginFlow(source)) + }) +} + +func TestRequiredAdditionalInfoFailureWarningLifecycle(t *testing.T) { + c, err := cache.NewStringCache(setting.Cache{Adapter: "memory", Interval: 1}) + require.NoError(t, err) + + mockNow := time.Unix(1_700_000_000, 0) + defer timeutil.MockSet(mockNow)() + SetRequiredAdditionalInfoFetchFailureWarning(c, "Google Workspace") + + warning := GetRequiredAdditionalInfoFailureWarning(c) + require.NotNil(t, warning) + assert.Equal(t, "Google Workspace", warning.SourceName) + assert.Equal(t, timeutil.TimeStamp(mockNow.Unix()), warning.LastFailedUnix) + + ClearRequiredAdditionalInfoFetchFailureWarning(c) + assert.Nil(t, GetRequiredAdditionalInfoFailureWarning(c)) +} + +func TestRequiredAdditionalInfoFailureWarningThrottle(t *testing.T) { + c, err := cache.NewStringCache(setting.Cache{Adapter: "memory", Interval: 1}) + require.NoError(t, err) + + first := time.Unix(1_700_000_000, 0) + defer timeutil.MockSet(first)() + SetRequiredAdditionalInfoFetchFailureWarning(c, "Google Workspace") + initial := GetRequiredAdditionalInfoFailureWarning(c) + require.NotNil(t, initial) + + // Within throttle window, keep previous timestamp to avoid cache churn. + timeutil.MockSet(first.Add(30 * time.Second)) + SetRequiredAdditionalInfoFetchFailureWarning(c, "Google Workspace") + throttled := GetRequiredAdditionalInfoFailureWarning(c) + require.NotNil(t, throttled) + assert.Equal(t, initial.LastFailedUnix, throttled.LastFailedUnix) + + // After throttle window, timestamp is refreshed. + timeutil.MockSet(first.Add(61 * time.Second)) + SetRequiredAdditionalInfoFetchFailureWarning(c, "Google Workspace") + refreshed := GetRequiredAdditionalInfoFailureWarning(c) + require.NotNil(t, refreshed) + assert.Equal(t, timeutil.TimeStamp(first.Add(61*time.Second).Unix()), refreshed.LastFailedUnix) +} diff --git a/services/auth/source/oauth2/required_additional_info_failure_warning.go b/services/auth/source/oauth2/required_additional_info_failure_warning.go new file mode 100644 index 0000000000..5eee30207a --- /dev/null +++ b/services/auth/source/oauth2/required_additional_info_failure_warning.go @@ -0,0 +1,77 @@ +// Copyright 2026 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "code.gitea.io/gitea/modules/cache" + "code.gitea.io/gitea/modules/json" + "code.gitea.io/gitea/modules/timeutil" +) + +const ( + requiredAdditionalInfoFailureWarningCacheKey = "oauth2.required.additional.info.failure.warning" + requiredAdditionalInfoFailureWarningTTL = 3600 + requiredAdditionalInfoFailureWarningThrottle = 60 +) + +type RequiredAdditionalInfoFailureWarning struct { + SourceName string `json:"sourceName"` + LastFailedUnix timeutil.TimeStamp `json:"lastFailedUnix"` +} + +func GetRequiredAdditionalInfoFailureWarning(c cache.StringCache) *RequiredAdditionalInfoFailureWarning { + if c == nil { + return nil + } + + rawWarning, ok := c.Get(requiredAdditionalInfoFailureWarningCacheKey) + if !ok || rawWarning == "" { + return nil + } + + warning := &RequiredAdditionalInfoFailureWarning{} + if err := json.Unmarshal([]byte(rawWarning), warning); err != nil { + _ = c.Delete(requiredAdditionalInfoFailureWarningCacheKey) + return nil + } + if warning.SourceName == "" || warning.LastFailedUnix.IsZero() { + _ = c.Delete(requiredAdditionalInfoFailureWarningCacheKey) + return nil + } + + return warning +} + +func SetRequiredAdditionalInfoFetchFailureWarning(c cache.StringCache, sourceName string) { + if c == nil { + return + } + if sourceName == "" { + sourceName = "OAuth2" + } + + now := timeutil.TimeStampNow() + current := GetRequiredAdditionalInfoFailureWarning(c) + if current != nil && current.SourceName == sourceName && now-current.LastFailedUnix < requiredAdditionalInfoFailureWarningThrottle { + return + } + + rawWarning, err := json.Marshal(&RequiredAdditionalInfoFailureWarning{ + SourceName: sourceName, + LastFailedUnix: now, + }) + if err != nil { + return + } + if err := c.Put(requiredAdditionalInfoFailureWarningCacheKey, string(rawWarning), requiredAdditionalInfoFailureWarningTTL); err != nil { + return + } +} + +func ClearRequiredAdditionalInfoFetchFailureWarning(c cache.StringCache) { + if c == nil { + return + } + _ = c.Delete(requiredAdditionalInfoFailureWarningCacheKey) +} diff --git a/services/context/context_template.go b/services/context/context_template.go index 0f083d097e..e5f8982293 100644 --- a/services/context/context_template.go +++ b/services/context/context_template.go @@ -17,6 +17,7 @@ import ( "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/util" "code.gitea.io/gitea/modules/web/middleware" + oauth2_source "code.gitea.io/gitea/services/auth/source/oauth2" "code.gitea.io/gitea/services/webtheme" ) @@ -78,6 +79,14 @@ func (c TemplateContext) CurrentWebBanner() *setting.WebBannerType { return nil } +func (c TemplateContext) CurrentRequiredAdditionalInfoFailureWarning() *oauth2_source.RequiredAdditionalInfoFailureWarning { + webCtx := GetWebContext(c) + if webCtx == nil || webCtx.Doer == nil || !webCtx.Doer.IsAdmin { + return nil + } + return oauth2_source.GetRequiredAdditionalInfoFailureWarning(webCtx.Cache) +} + // AppFullLink returns a full URL link with AppSubURL for the given app link (no AppSubURL) // If no link is given, it returns the current app full URL with sub-path but without trailing slash (that's why it is not named as AppURL) func (c TemplateContext) AppFullLink(link ...string) template.URL { diff --git a/templates/base/head_banner.tmpl b/templates/base/head_banner.tmpl index d237161622..dc4e76b283 100644 --- a/templates/base/head_banner.tmpl +++ b/templates/base/head_banner.tmpl @@ -9,3 +9,13 @@ {{end}} + +{{$requiredAdditionalInfoWarning := ctx.CurrentRequiredAdditionalInfoFailureWarning}} +{{if $requiredAdditionalInfoWarning}} +
+{{end}}