mirror of
https://github.com/go-gitea/gitea.git
synced 2026-05-10 05:21:54 +02:00
Merge b6d082d136698ff53d436fbea31df3a01aaa75e6 into ce089f498bce32305b2d9e8c6adfd8cb7c82f88f
This commit is contained in:
commit
000a842436
138
modules/google/groups.go
Normal file
138
modules/google/groups.go
Normal file
@ -0,0 +1,138 @@
|
||||
// Copyright 2026 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package google
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"code.gitea.io/gitea/modules/json"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
|
||||
"github.com/markbates/goth"
|
||||
)
|
||||
|
||||
const (
|
||||
IAMScope = "https://www.googleapis.com/auth/cloud-identity.groups.readonly"
|
||||
defaultIAMGroupsEndpoint = "https://content-cloudidentity.googleapis.com/v1/groups/-/memberships:searchDirectGroups"
|
||||
)
|
||||
|
||||
// maxGroupPages is the maximum number of pages fetched from the Google
|
||||
// Cloud Identity API. The API returns up to 200 groups per page by default,
|
||||
// so this caps group membership at 4,000 groups per user — far beyond any
|
||||
// realistic Google Workspace organization.
|
||||
const maxGroupPages = 20
|
||||
|
||||
// Client calls Google Workspace APIs.
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
groupsEndpoint string
|
||||
claimName string
|
||||
failLoginOnAdditionalInfoError bool
|
||||
}
|
||||
|
||||
// NewClient creates a Client using the given authenticated HTTP client.
|
||||
// The client should be built from an OAuth2 token carrying IAMScope.
|
||||
func NewClient(httpClient *http.Client, claimName string, failLoginOnAdditionalInfoError bool) *Client {
|
||||
return &Client{
|
||||
httpClient: httpClient,
|
||||
groupsEndpoint: defaultIAMGroupsEndpoint,
|
||||
claimName: claimName,
|
||||
failLoginOnAdditionalInfoError: failLoginOnAdditionalInfoError,
|
||||
}
|
||||
}
|
||||
|
||||
// groupMembership represents a single membership entry returned by the
|
||||
// Cloud Identity Groups API searchDirectGroups endpoint.
|
||||
type groupMembership struct {
|
||||
GroupKey struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"groupKey"`
|
||||
}
|
||||
|
||||
// groupsResponse is the paged response from the Cloud Identity API.
|
||||
type groupsResponse struct {
|
||||
Memberships []groupMembership `json:"memberships"`
|
||||
NextPageToken string `json:"nextPageToken"`
|
||||
}
|
||||
|
||||
// FetchGroups queries the Google Cloud Identity Groups API for all
|
||||
// groups the given user (identified by email) is a direct member of.
|
||||
// The caller must supply an HTTP client already authenticated with an access
|
||||
// token that carries the IAMScope scope.
|
||||
func (c *Client) FetchGroups(ctx context.Context, email string) ([]string, error) {
|
||||
groups := make([]string, 0, 16)
|
||||
pageToken := ""
|
||||
|
||||
for range make([]struct{}, maxGroupPages) {
|
||||
params := url.Values{}
|
||||
params.Set("query", fmt.Sprintf("member_key_id=='%s'", email))
|
||||
if pageToken != "" {
|
||||
params.Set("pageToken", pageToken)
|
||||
}
|
||||
apiURL := c.groupsEndpoint + "?" + params.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google groups: build request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google groups: HTTP request: %w", err)
|
||||
}
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if readErr != nil {
|
||||
return nil, fmt.Errorf("google groups: read response: %w", readErr)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("google groups: API returned %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var page groupsResponse
|
||||
if err := json.Unmarshal(body, &page); err != nil {
|
||||
return nil, fmt.Errorf("google groups: decode response: %w", err)
|
||||
}
|
||||
|
||||
for _, m := range page.Memberships {
|
||||
if m.GroupKey.ID != "" {
|
||||
groups = append(groups, m.GroupKey.ID)
|
||||
}
|
||||
}
|
||||
|
||||
if page.NextPageToken == "" {
|
||||
break
|
||||
}
|
||||
pageToken = page.NextPageToken
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// FetchAdditionalInfo implements oauth2.AdditionalInfoProvider.
|
||||
// It fetches Google Workspace group memberships and injects them into
|
||||
// gothUser.RawData under the given claimName key.
|
||||
func (c *Client) FetchAdditionalInfo(ctx context.Context, user goth.User) (goth.User, error) {
|
||||
groups, err := c.FetchGroups(ctx, user.Email)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
if user.RawData == nil {
|
||||
user.RawData = make(map[string]any)
|
||||
}
|
||||
if existing, has := user.RawData[c.claimName]; has {
|
||||
log.Warn("OAuth2 Google: RawData already contains claim %q with some value. Consider to use different claim name for groups information", c.claimName, existing)
|
||||
}
|
||||
user.RawData[c.claimName] = groups
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// FailLoginOnAdditionalInfoError implements oauth2.AdditionalInfoProvider.
|
||||
func (c *Client) FailLoginOnAdditionalInfoError() bool {
|
||||
return c.failLoginOnAdditionalInfoError
|
||||
}
|
||||
174
modules/google/groups_test.go
Normal file
174
modules/google/groups_test.go
Normal file
@ -0,0 +1,174 @@
|
||||
// Copyright 2026 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package google
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/markbates/goth"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestClient(t *testing.T, server *httptest.Server) *Client {
|
||||
t.Helper()
|
||||
c := NewClient(server.Client(), "groups", false)
|
||||
c.groupsEndpoint = server.URL
|
||||
return c
|
||||
}
|
||||
|
||||
func mockGroupsServer(t *testing.T, expectedEmail string, pages [][]string) *httptest.Server {
|
||||
t.Helper()
|
||||
pageIndex := 0
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Validate the query contains the correct member_key_id
|
||||
expectedQuery := fmt.Sprintf("member_key_id=='%s'", expectedEmail)
|
||||
assert.Equal(t, expectedQuery, r.URL.Query().Get("query"))
|
||||
|
||||
var memberships []string
|
||||
for _, g := range pages[pageIndex] {
|
||||
memberships = append(memberships, fmt.Sprintf(`{"groupKey":{"id":%q}}`, g))
|
||||
}
|
||||
|
||||
nextPageToken := ""
|
||||
if pageIndex < len(pages)-1 {
|
||||
nextPageToken = "page-token"
|
||||
}
|
||||
pageIndex++
|
||||
|
||||
body := fmt.Sprintf(
|
||||
`{"memberships":[%s],"nextPageToken":%q}`,
|
||||
strings.Join(memberships, ","),
|
||||
nextPageToken,
|
||||
)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = fmt.Fprint(w, body)
|
||||
}))
|
||||
}
|
||||
|
||||
func TestFetchGoogleGroups_SinglePage(t *testing.T) {
|
||||
server := mockGroupsServer(t, "user@example.com", [][]string{
|
||||
{"group-a@example.com", "group-b@example.com"},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
client := newTestClient(t, server)
|
||||
groups, err := client.FetchGroups(context.Background(), "user@example.com")
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"group-a@example.com", "group-b@example.com"}, groups)
|
||||
}
|
||||
|
||||
func TestFetchGroups_SinglePage(t *testing.T) {
|
||||
server := mockGroupsServer(t, "user@example.com", [][]string{
|
||||
{"group-a@example.com", "group-b@example.com"},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
client := newTestClient(t, server)
|
||||
groups, err := client.FetchGroups(context.Background(), "user@example.com")
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"group-a@example.com", "group-b@example.com"}, groups)
|
||||
}
|
||||
|
||||
func TestFetchGoogleGroups_MultiPage(t *testing.T) {
|
||||
server := mockGroupsServer(t, "user@example.com", [][]string{
|
||||
{"group-a@example.com"},
|
||||
{"group-b@example.com", "group-c@example.com"},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
client := newTestClient(t, server)
|
||||
groups, err := client.FetchGroups(context.Background(), "user@example.com")
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"group-a@example.com", "group-b@example.com", "group-c@example.com"}, groups)
|
||||
}
|
||||
|
||||
func TestFetchGoogleGroups_Empty(t *testing.T) {
|
||||
server := mockGroupsServer(t, "user@example.com", [][]string{{}})
|
||||
defer server.Close()
|
||||
|
||||
client := newTestClient(t, server)
|
||||
groups, err := client.FetchGroups(context.Background(), "user@example.com")
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, groups)
|
||||
}
|
||||
|
||||
func TestFetchGoogleGroups_APIError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
_, _ = fmt.Fprint(w, `{"error":"forbidden"}`)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestClient(t, server)
|
||||
groups, err := client.FetchGroups(context.Background(), "user@example.com")
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, groups)
|
||||
assert.Contains(t, err.Error(), "403")
|
||||
}
|
||||
|
||||
func TestFetchGoogleGroups_InvalidJSON(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = fmt.Fprint(w, `not valid json`)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestClient(t, server)
|
||||
groups, err := client.FetchGroups(context.Background(), "user@example.com")
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, groups)
|
||||
}
|
||||
|
||||
func TestFetchAdditionalInfo_InjectsClaimBeforeValidation(t *testing.T) {
|
||||
server := mockGroupsServer(t, "user@example.com", [][]string{
|
||||
{"required-group@example.com"},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
c := newTestClient(t, server)
|
||||
c.claimName = "groups"
|
||||
|
||||
user := goth.User{
|
||||
Email: "user@example.com",
|
||||
RawData: map[string]any{},
|
||||
}
|
||||
|
||||
enriched, err := c.FetchAdditionalInfo(context.Background(), user)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the claim is present and contains the group — simulating what
|
||||
// RequiredClaimName validation would check after enrichment runs.
|
||||
groups, ok := enriched.RawData["groups"]
|
||||
require.True(t, ok, "groups claim must be present in RawData after enrichment")
|
||||
groupSlice, ok := groups.([]string)
|
||||
require.True(t, ok)
|
||||
assert.Contains(t, groupSlice, "required-group@example.com")
|
||||
}
|
||||
|
||||
func TestFetchAdditionalInfo_ErrorDoesNotInjectClaim(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
_, _ = fmt.Fprint(w, `{"error":"forbidden"}`)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
c := newTestClient(t, server)
|
||||
c.claimName = "groups"
|
||||
|
||||
user := goth.User{
|
||||
Email: "user@example.com",
|
||||
RawData: map[string]any{},
|
||||
}
|
||||
|
||||
enriched, err := c.FetchAdditionalInfo(context.Background(), user)
|
||||
require.Error(t, err)
|
||||
_, hasGroups := enriched.RawData["groups"]
|
||||
assert.False(t, hasGroups)
|
||||
}
|
||||
@ -428,6 +428,8 @@
|
||||
"auth.oauth.signin.error.general": "There was an error processing the authorization request: %s. If this error persists, please contact the site administrator.",
|
||||
"auth.oauth.signin.error.access_denied": "The authorization request was denied.",
|
||||
"auth.oauth.signin.error.temporarily_unavailable": "Authorization failed because the authentication server is temporarily unavailable. Please try again later.",
|
||||
"auth.oauth.required_additional_info_fetch_failed_banner.title": "OAuth2 additional information synchronization warning",
|
||||
"auth.oauth.required_additional_info_fetch_failed_banner.desc": "Gitea recently failed to retrieve required additional information from authentication source \"%s\" (last failure: %s). Users may keep their previously synchronized admin/restricted/team state until retrieval succeeds again. Check logs for details.",
|
||||
"auth.oauth_callback_unable_auto_reg": "Auto Registration is enabled, but OAuth2 Provider %[1]s returned missing fields: %[2]s, unable to create an account automatically. Please create or link to an account, or contact the site administrator.",
|
||||
"auth.openid_connect_submit": "Connect",
|
||||
"auth.openid_connect_title": "Connect to an existing account",
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"code.gitea.io/gitea/models/db"
|
||||
issues_model "code.gitea.io/gitea/models/issues"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
oauth2_source "code.gitea.io/gitea/services/auth/source/oauth2"
|
||||
"code.gitea.io/gitea/services/context"
|
||||
)
|
||||
|
||||
@ -69,8 +70,16 @@ type pageGlobalDataType struct {
|
||||
IsSigned bool
|
||||
IsSiteAdmin bool
|
||||
|
||||
GetNotificationUnreadCount func() int64
|
||||
GetActiveStopwatch func() *StopwatchTmplInfo
|
||||
GetNotificationUnreadCount func() int64
|
||||
GetActiveStopwatch func() *StopwatchTmplInfo
|
||||
GetRequiredAdditionalInfoFailureWarning func() *oauth2_source.RequiredAdditionalInfoFailureWarning
|
||||
}
|
||||
|
||||
func oauth2RequiredAdditionalInfoFailureWarning(ctx *context.Context) *oauth2_source.RequiredAdditionalInfoFailureWarning {
|
||||
if ctx.Doer == nil || !ctx.Doer.IsAdmin {
|
||||
return nil
|
||||
}
|
||||
return oauth2_source.GetRequiredAdditionalInfoFailureWarning(ctx.Cache)
|
||||
}
|
||||
|
||||
func PageGlobalData(ctx *context.Context) {
|
||||
@ -79,5 +88,8 @@ func PageGlobalData(ctx *context.Context) {
|
||||
data.IsSiteAdmin = ctx.Doer != nil && ctx.Doer.IsAdmin
|
||||
data.GetNotificationUnreadCount = sync.OnceValue(func() int64 { return notificationUnreadCount(ctx) })
|
||||
data.GetActiveStopwatch = sync.OnceValue(func() *StopwatchTmplInfo { return getActiveStopwatch(ctx) })
|
||||
data.GetRequiredAdditionalInfoFailureWarning = sync.OnceValue(func() *oauth2_source.RequiredAdditionalInfoFailureWarning {
|
||||
return oauth2RequiredAdditionalInfoFailureWarning(ctx)
|
||||
})
|
||||
ctx.Data["PageGlobalData"] = data
|
||||
}
|
||||
|
||||
41
routers/common/pagetmpl_test.go
Normal file
41
routers/common/pagetmpl_test.go
Normal file
@ -0,0 +1,41 @@
|
||||
// Copyright 2026 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/cache"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
oauth2_source "code.gitea.io/gitea/services/auth/source/oauth2"
|
||||
"code.gitea.io/gitea/services/context"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOAuth2RequiredAdditionalInfoFailureWarning_AdminOnlyVisibility(t *testing.T) {
|
||||
c, err := cache.NewStringCache(setting.Cache{Adapter: "memory", Interval: 1})
|
||||
require.NoError(t, err)
|
||||
|
||||
defer timeutil.MockSet(time.Unix(1_700_000_000, 0))()
|
||||
oauth2_source.SetRequiredAdditionalInfoFetchFailureWarning(c, "Google Workspace")
|
||||
|
||||
adminCtx := &context.Context{
|
||||
Doer: &user_model.User{IsAdmin: true},
|
||||
Cache: c,
|
||||
}
|
||||
nonAdminCtx := &context.Context{
|
||||
Doer: &user_model.User{IsAdmin: false},
|
||||
Cache: c,
|
||||
}
|
||||
|
||||
adminWarning := oauth2RequiredAdditionalInfoFailureWarning(adminCtx)
|
||||
require.NotNil(t, adminWarning)
|
||||
assert.Equal(t, "Google Workspace", adminWarning.SourceName)
|
||||
assert.Nil(t, oauth2RequiredAdditionalInfoFailureWarning(nonAdminCtx))
|
||||
}
|
||||
@ -232,6 +232,10 @@ func claimValueToStringSet(claimValue any) container.Set[string] {
|
||||
}
|
||||
|
||||
func syncGroupsToTeams(ctx *context.Context, source *oauth2.Source, gothUser *goth.User, u *user_model.User) error {
|
||||
if !shouldSyncFromGroupClaim(source, gothUser) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if source.GroupTeamMap != "" || source.GroupTeamMapRemoval {
|
||||
groupTeamMapping, err := auth_module.UnmarshalGroupTeamMapping(source.GroupTeamMap)
|
||||
if err != nil {
|
||||
@ -257,7 +261,22 @@ func getClaimedGroups(source *oauth2.Source, gothUser *goth.User) container.Set[
|
||||
return claimValueToStringSet(groupClaims)
|
||||
}
|
||||
|
||||
func shouldSyncFromGroupClaim(source *oauth2.Source, gothUser *goth.User) bool {
|
||||
// Keep historical behavior for all providers except Google Workspace:
|
||||
// if the claim is missing, group-derived sync still runs on an empty set.
|
||||
if source.Provider != "gplus" {
|
||||
return true
|
||||
}
|
||||
|
||||
_, hasGroupClaim := gothUser.RawData[source.GroupClaimName]
|
||||
return hasGroupClaim
|
||||
}
|
||||
|
||||
func getUserAdminAndRestrictedFromGroupClaims(source *oauth2.Source, gothUser *goth.User) (isAdmin optional.Option[user_service.UpdateOptionField[bool]], isRestricted optional.Option[bool]) {
|
||||
if !shouldSyncFromGroupClaim(source, gothUser) {
|
||||
return isAdmin, isRestricted
|
||||
}
|
||||
|
||||
groups := getClaimedGroups(source, gothUser)
|
||||
|
||||
if source.AdminGroup != "" {
|
||||
@ -461,6 +480,32 @@ func oAuth2UserLoginCallback(ctx *context.Context, authSource *auth.Source, requ
|
||||
return nil, goth.User{}, err
|
||||
}
|
||||
|
||||
// Enrichment must run before RequiredClaimName validation so that claims
|
||||
// injected by the provider (e.g. Google Workspace groups fetched via the
|
||||
// Cloud Identity API) are available when the required-claim check executes.
|
||||
// Moving this block after the RequiredClaimName check would cause users to
|
||||
// be incorrectly rejected when RequiredClaimName references an injected claim.
|
||||
if provider := oauth2.GetAdditionalInfoProvider(oauth2Source, &gothUser); provider != nil {
|
||||
enriched, err := provider.FetchAdditionalInfo(ctx, gothUser)
|
||||
if err != nil {
|
||||
log.Warn("OAuth2: failed to fetch additional info for %s: %v", gothUser.Email, err)
|
||||
if provider.FailLoginOnAdditionalInfoError() {
|
||||
sourceName := authSource.Name
|
||||
if sourceName == "" {
|
||||
sourceName = fmt.Sprintf("source #%d", authSource.ID)
|
||||
}
|
||||
oauth2.SetRequiredAdditionalInfoFetchFailureWarning(ctx.Cache, sourceName)
|
||||
// Fail closed only when login directly depends on the group claim
|
||||
// (for example RequiredClaimName == GroupClaimName). Other
|
||||
// group-based sync features are fail-open and preserve prior state.
|
||||
return nil, goth.User{}, user_model.ErrUserProhibitLogin{Name: gothUser.UserID}
|
||||
}
|
||||
} else {
|
||||
oauth2.ClearRequiredAdditionalInfoFetchFailureWarning(ctx.Cache)
|
||||
gothUser = enriched
|
||||
}
|
||||
}
|
||||
|
||||
if oauth2Source.RequiredClaimName != "" {
|
||||
claimInterface, has := gothUser.RawData[oauth2Source.RequiredClaimName]
|
||||
if !has {
|
||||
|
||||
108
routers/web/auth/oauth_group_claims_test.go
Normal file
108
routers/web/auth/oauth_group_claims_test.go
Normal file
@ -0,0 +1,108 @@
|
||||
// Copyright 2026 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/modules/optional"
|
||||
"code.gitea.io/gitea/services/auth/source/oauth2"
|
||||
user_service "code.gitea.io/gitea/services/user"
|
||||
|
||||
"github.com/markbates/goth"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestShouldSyncFromGroupClaim(t *testing.T) {
|
||||
t.Run("google claim missing", func(t *testing.T) {
|
||||
source := &oauth2.Source{
|
||||
Provider: "gplus",
|
||||
GroupClaimName: "groups",
|
||||
}
|
||||
user := &goth.User{
|
||||
RawData: map[string]any{},
|
||||
}
|
||||
assert.False(t, shouldSyncFromGroupClaim(source, user))
|
||||
})
|
||||
|
||||
t.Run("google claim present and empty", func(t *testing.T) {
|
||||
source := &oauth2.Source{
|
||||
Provider: "gplus",
|
||||
GroupClaimName: "groups",
|
||||
}
|
||||
user := &goth.User{
|
||||
RawData: map[string]any{
|
||||
"groups": []string{},
|
||||
},
|
||||
}
|
||||
assert.True(t, shouldSyncFromGroupClaim(source, user))
|
||||
})
|
||||
|
||||
t.Run("non google provider keeps old behavior", func(t *testing.T) {
|
||||
source := &oauth2.Source{
|
||||
Provider: "openidConnect",
|
||||
GroupClaimName: "groups",
|
||||
}
|
||||
user := &goth.User{
|
||||
RawData: map[string]any{},
|
||||
}
|
||||
assert.True(t, shouldSyncFromGroupClaim(source, user))
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetUserAdminAndRestrictedFromGroupClaims_GoogleMissingClaim(t *testing.T) {
|
||||
source := &oauth2.Source{
|
||||
Provider: "gplus",
|
||||
GroupClaimName: "groups",
|
||||
AdminGroup: "g-admins@example.com",
|
||||
RestrictedGroup: "g-restricted@example.com",
|
||||
}
|
||||
user := &goth.User{
|
||||
RawData: map[string]any{},
|
||||
}
|
||||
|
||||
isAdmin, isRestricted := getUserAdminAndRestrictedFromGroupClaims(source, user)
|
||||
|
||||
assert.False(t, isAdmin.Has())
|
||||
assert.Equal(t, optional.None[bool](), isRestricted)
|
||||
}
|
||||
|
||||
func TestGetUserAdminAndRestrictedFromGroupClaims_GoogleEmptyClaim(t *testing.T) {
|
||||
source := &oauth2.Source{
|
||||
Provider: "gplus",
|
||||
GroupClaimName: "groups",
|
||||
AdminGroup: "g-admins@example.com",
|
||||
RestrictedGroup: "g-restricted@example.com",
|
||||
}
|
||||
user := &goth.User{
|
||||
RawData: map[string]any{
|
||||
"groups": []string{},
|
||||
},
|
||||
}
|
||||
|
||||
isAdmin, isRestricted := getUserAdminAndRestrictedFromGroupClaims(source, user)
|
||||
|
||||
require.True(t, isAdmin.Has())
|
||||
assert.Equal(t, user_service.UpdateOptionFieldFromSync(false), isAdmin)
|
||||
assert.Equal(t, optional.Some(false), isRestricted)
|
||||
}
|
||||
|
||||
func TestGetUserAdminAndRestrictedFromGroupClaims_NonGoogleMissingClaim(t *testing.T) {
|
||||
source := &oauth2.Source{
|
||||
Provider: "openidConnect",
|
||||
GroupClaimName: "groups",
|
||||
AdminGroup: "g-admins@example.com",
|
||||
RestrictedGroup: "g-restricted@example.com",
|
||||
}
|
||||
user := &goth.User{
|
||||
RawData: map[string]any{},
|
||||
}
|
||||
|
||||
isAdmin, isRestricted := getUserAdminAndRestrictedFromGroupClaims(source, user)
|
||||
|
||||
require.True(t, isAdmin.Has())
|
||||
assert.Equal(t, user_service.UpdateOptionFieldFromSync(false), isAdmin)
|
||||
assert.Equal(t, optional.Some(false), isRestricted)
|
||||
}
|
||||
60
services/auth/source/oauth2/additional_info_provider.go
Normal file
60
services/auth/source/oauth2/additional_info_provider.go
Normal file
@ -0,0 +1,60 @@
|
||||
// Copyright 2026 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
|
||||
google_module "code.gitea.io/gitea/modules/google"
|
||||
|
||||
"github.com/markbates/goth"
|
||||
go_oauth2 "golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// AdditionalInfoProvider is implemented by OAuth2 providers that can fetch
|
||||
// additional user information (such as group memberships) that is not
|
||||
// included in the standard token or userinfo response.
|
||||
// The provider receives the resolved goth user, and returns a modified copy
|
||||
// with any extra data injected into RawData.
|
||||
type AdditionalInfoProvider interface {
|
||||
FetchAdditionalInfo(ctx context.Context, user goth.User) (goth.User, error)
|
||||
FailLoginOnAdditionalInfoError() bool
|
||||
}
|
||||
|
||||
// GetAdditionalInfoProvider returns an AdditionalInfoProvider for the given
|
||||
// source if that provider supports fetching additional info, or nil if none
|
||||
// applies. The returned provider is already configured with an authenticated
|
||||
// HTTP client built from the access token.
|
||||
func GetAdditionalInfoProvider(source *Source, gothUser *goth.User) AdditionalInfoProvider {
|
||||
switch source.Provider {
|
||||
case "gplus":
|
||||
if slices.Contains(source.Scopes, google_module.IAMScope) {
|
||||
claimName := source.GroupClaimName
|
||||
if claimName == "" {
|
||||
claimName = "groups"
|
||||
}
|
||||
oauthToken := &go_oauth2.Token{AccessToken: gothUser.AccessToken}
|
||||
// Note: we use only the access token without a refresh token.
|
||||
// This is intentional — the token is issued moments before this
|
||||
// call during the login flow and is guaranteed to be fresh.
|
||||
authenticatedClient := go_oauth2.NewClient(context.Background(), go_oauth2.StaticTokenSource(oauthToken))
|
||||
return google_module.NewClient(authenticatedClient, claimName, isGoogleGroupClaimRequiredForLoginFlow(source))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isGoogleGroupClaimRequiredForLoginFlow(source *Source) bool {
|
||||
groupClaimName := source.GroupClaimName
|
||||
if groupClaimName == "" {
|
||||
groupClaimName = "groups"
|
||||
}
|
||||
|
||||
// Fail closed only when login itself depends on the group claim.
|
||||
//
|
||||
// Admin/restricted/team sync can preserve the user's previous state when the
|
||||
// group claim is missing, so those options intentionally stay fail-open.
|
||||
return source.RequiredClaimName == groupClaimName
|
||||
}
|
||||
110
services/auth/source/oauth2/additional_info_provider_test.go
Normal file
110
services/auth/source/oauth2/additional_info_provider_test.go
Normal file
@ -0,0 +1,110 @@
|
||||
// Copyright 2026 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/modules/cache"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsGoogleGroupClaimRequiredForLoginFlow(t *testing.T) {
|
||||
t.Run("no group-dependent options", func(t *testing.T) {
|
||||
source := &Source{
|
||||
GroupClaimName: "groups",
|
||||
}
|
||||
assert.False(t, isGoogleGroupClaimRequiredForLoginFlow(source))
|
||||
})
|
||||
|
||||
t.Run("required claim uses group claim", func(t *testing.T) {
|
||||
source := &Source{
|
||||
GroupClaimName: "custom_groups",
|
||||
RequiredClaimName: "custom_groups",
|
||||
}
|
||||
assert.True(t, isGoogleGroupClaimRequiredForLoginFlow(source))
|
||||
})
|
||||
|
||||
t.Run("required claim uses default groups claim", func(t *testing.T) {
|
||||
source := &Source{
|
||||
RequiredClaimName: "groups",
|
||||
}
|
||||
assert.True(t, isGoogleGroupClaimRequiredForLoginFlow(source))
|
||||
})
|
||||
|
||||
t.Run("admin group configured", func(t *testing.T) {
|
||||
source := &Source{
|
||||
AdminGroup: "admins@example.com",
|
||||
}
|
||||
assert.False(t, isGoogleGroupClaimRequiredForLoginFlow(source))
|
||||
})
|
||||
|
||||
t.Run("restricted group configured", func(t *testing.T) {
|
||||
source := &Source{
|
||||
RestrictedGroup: "restricted@example.com",
|
||||
}
|
||||
assert.False(t, isGoogleGroupClaimRequiredForLoginFlow(source))
|
||||
})
|
||||
|
||||
t.Run("group team mapping configured", func(t *testing.T) {
|
||||
source := &Source{
|
||||
GroupTeamMap: "{\"a\": {\"org\": [\"team\"]}}",
|
||||
}
|
||||
assert.False(t, isGoogleGroupClaimRequiredForLoginFlow(source))
|
||||
})
|
||||
|
||||
t.Run("group team mapping removal enabled", func(t *testing.T) {
|
||||
source := &Source{
|
||||
GroupTeamMapRemoval: true,
|
||||
}
|
||||
assert.False(t, isGoogleGroupClaimRequiredForLoginFlow(source))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequiredAdditionalInfoFailureWarningLifecycle(t *testing.T) {
|
||||
c, err := cache.NewStringCache(setting.Cache{Adapter: "memory", Interval: 1})
|
||||
require.NoError(t, err)
|
||||
|
||||
mockNow := time.Unix(1_700_000_000, 0)
|
||||
defer timeutil.MockSet(mockNow)()
|
||||
SetRequiredAdditionalInfoFetchFailureWarning(c, "Google Workspace")
|
||||
|
||||
warning := GetRequiredAdditionalInfoFailureWarning(c)
|
||||
require.NotNil(t, warning)
|
||||
assert.Equal(t, "Google Workspace", warning.SourceName)
|
||||
assert.Equal(t, timeutil.TimeStamp(mockNow.Unix()), warning.LastFailedUnix)
|
||||
|
||||
ClearRequiredAdditionalInfoFetchFailureWarning(c)
|
||||
assert.Nil(t, GetRequiredAdditionalInfoFailureWarning(c))
|
||||
}
|
||||
|
||||
func TestRequiredAdditionalInfoFailureWarningThrottle(t *testing.T) {
|
||||
c, err := cache.NewStringCache(setting.Cache{Adapter: "memory", Interval: 1})
|
||||
require.NoError(t, err)
|
||||
|
||||
first := time.Unix(1_700_000_000, 0)
|
||||
defer timeutil.MockSet(first)()
|
||||
SetRequiredAdditionalInfoFetchFailureWarning(c, "Google Workspace")
|
||||
initial := GetRequiredAdditionalInfoFailureWarning(c)
|
||||
require.NotNil(t, initial)
|
||||
|
||||
// Within throttle window, keep previous timestamp to avoid cache churn.
|
||||
timeutil.MockSet(first.Add(30 * time.Second))
|
||||
SetRequiredAdditionalInfoFetchFailureWarning(c, "Google Workspace")
|
||||
throttled := GetRequiredAdditionalInfoFailureWarning(c)
|
||||
require.NotNil(t, throttled)
|
||||
assert.Equal(t, initial.LastFailedUnix, throttled.LastFailedUnix)
|
||||
|
||||
// After throttle window, timestamp is refreshed.
|
||||
timeutil.MockSet(first.Add(61 * time.Second))
|
||||
SetRequiredAdditionalInfoFetchFailureWarning(c, "Google Workspace")
|
||||
refreshed := GetRequiredAdditionalInfoFailureWarning(c)
|
||||
require.NotNil(t, refreshed)
|
||||
assert.Equal(t, timeutil.TimeStamp(first.Add(61*time.Second).Unix()), refreshed.LastFailedUnix)
|
||||
}
|
||||
@ -0,0 +1,77 @@
|
||||
// Copyright 2026 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"code.gitea.io/gitea/modules/cache"
|
||||
"code.gitea.io/gitea/modules/json"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
)
|
||||
|
||||
const (
|
||||
requiredAdditionalInfoFailureWarningCacheKey = "oauth2.required.additional.info.failure.warning"
|
||||
requiredAdditionalInfoFailureWarningTTL = 3600
|
||||
requiredAdditionalInfoFailureWarningThrottle = 60
|
||||
)
|
||||
|
||||
type RequiredAdditionalInfoFailureWarning struct {
|
||||
SourceName string `json:"sourceName"`
|
||||
LastFailedUnix timeutil.TimeStamp `json:"lastFailedUnix"`
|
||||
}
|
||||
|
||||
func GetRequiredAdditionalInfoFailureWarning(c cache.StringCache) *RequiredAdditionalInfoFailureWarning {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
rawWarning, ok := c.Get(requiredAdditionalInfoFailureWarningCacheKey)
|
||||
if !ok || rawWarning == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
warning := &RequiredAdditionalInfoFailureWarning{}
|
||||
if err := json.Unmarshal([]byte(rawWarning), warning); err != nil {
|
||||
_ = c.Delete(requiredAdditionalInfoFailureWarningCacheKey)
|
||||
return nil
|
||||
}
|
||||
if warning.SourceName == "" || warning.LastFailedUnix.IsZero() {
|
||||
_ = c.Delete(requiredAdditionalInfoFailureWarningCacheKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
return warning
|
||||
}
|
||||
|
||||
func SetRequiredAdditionalInfoFetchFailureWarning(c cache.StringCache, sourceName string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if sourceName == "" {
|
||||
sourceName = "OAuth2"
|
||||
}
|
||||
|
||||
now := timeutil.TimeStampNow()
|
||||
current := GetRequiredAdditionalInfoFailureWarning(c)
|
||||
if current != nil && current.SourceName == sourceName && now-current.LastFailedUnix < requiredAdditionalInfoFailureWarningThrottle {
|
||||
return
|
||||
}
|
||||
|
||||
rawWarning, err := json.Marshal(&RequiredAdditionalInfoFailureWarning{
|
||||
SourceName: sourceName,
|
||||
LastFailedUnix: now,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err := c.Put(requiredAdditionalInfoFailureWarningCacheKey, string(rawWarning), requiredAdditionalInfoFailureWarningTTL); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func ClearRequiredAdditionalInfoFetchFailureWarning(c cache.StringCache) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
_ = c.Delete(requiredAdditionalInfoFailureWarningCacheKey)
|
||||
}
|
||||
@ -17,6 +17,7 @@ import (
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
"code.gitea.io/gitea/modules/web/middleware"
|
||||
oauth2_source "code.gitea.io/gitea/services/auth/source/oauth2"
|
||||
"code.gitea.io/gitea/services/webtheme"
|
||||
)
|
||||
|
||||
@ -78,6 +79,14 @@ func (c TemplateContext) CurrentWebBanner() *setting.WebBannerType {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c TemplateContext) CurrentRequiredAdditionalInfoFailureWarning() *oauth2_source.RequiredAdditionalInfoFailureWarning {
|
||||
webCtx := GetWebContext(c)
|
||||
if webCtx == nil || webCtx.Doer == nil || !webCtx.Doer.IsAdmin {
|
||||
return nil
|
||||
}
|
||||
return oauth2_source.GetRequiredAdditionalInfoFailureWarning(webCtx.Cache)
|
||||
}
|
||||
|
||||
// AppFullLink returns a full URL link with AppSubURL for the given app link (no AppSubURL)
|
||||
// If no link is given, it returns the current app full URL with sub-path but without trailing slash (that's why it is not named as AppURL)
|
||||
func (c TemplateContext) AppFullLink(link ...string) template.URL {
|
||||
|
||||
@ -9,3 +9,13 @@
|
||||
</button>
|
||||
</div>
|
||||
{{end}}
|
||||
|
||||
{{$requiredAdditionalInfoWarning := ctx.CurrentRequiredAdditionalInfoFailureWarning}}
|
||||
{{if $requiredAdditionalInfoWarning}}
|
||||
<div class="ui warning message web-banner-container">
|
||||
<div class="render-content markup web-banner-content">
|
||||
<strong>{{ctx.Locale.Tr "auth.oauth.required_additional_info_fetch_failed_banner.title"}}</strong>
|
||||
<p>{{ctx.Locale.Tr "auth.oauth.required_additional_info_fetch_failed_banner.desc" $requiredAdditionalInfoWarning.SourceName (DateUtils.AbsoluteShort $requiredAdditionalInfoWarning.LastFailedUnix)}}</p>
|
||||
</div>
|
||||
</div>
|
||||
{{end}}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user