diff --git a/routers/web/auth/oauth.go b/routers/web/auth/oauth.go index 7bed3523ed..f7d9c3c34a 100644 --- a/routers/web/auth/oauth.go +++ b/routers/web/auth/oauth.go @@ -188,10 +188,6 @@ func SignInOAuthCallback(ctx *context.Context) { source := authSource.Cfg.(*oauth2.Source) - isAdmin, isRestricted := getUserAdminAndRestrictedFromGroupClaims(source, &gothUser) - u.IsAdmin = isAdmin.ValueOrDefault(user_service.UpdateOptionField[bool]{FieldValue: false}).FieldValue - u.IsRestricted = isRestricted.ValueOrDefault(setting.Service.DefaultUserIsRestricted) - linkAccountData := &LinkAccountData{authSource.ID, gothUser} if setting.OAuth2Client.AccountLinking == setting.OAuth2AccountLinkingDisabled { linkAccountData = nil @@ -373,9 +369,6 @@ func handleOAuth2SignIn(ctx *context.Context, authSource *auth.Source, u *user_m opts.IsActive = optional.Some(true) } - // Update GroupClaims - opts.IsAdmin, opts.IsRestricted = getUserAdminAndRestrictedFromGroupClaims(oauth2Source, &gothUser) - if oauth2Source.GroupTeamMap != "" || oauth2Source.GroupTeamMapRemoval { if err := source_service.SyncGroupsToTeams(ctx, u, groups, groupTeamMapping, oauth2Source.GroupTeamMapRemoval); err != nil { ctx.ServerError("SyncGroupsToTeams", err) diff --git a/routers/web/auth/oauth_signin_sync.go b/routers/web/auth/oauth_signin_sync.go index a939a0e71e..f5ec66e006 100644 --- a/routers/web/auth/oauth_signin_sync.go +++ b/routers/web/auth/oauth_signin_sync.go @@ -14,6 +14,7 @@ import ( asymkey_service "gitea.dev/services/asymkey" "gitea.dev/services/auth/source/oauth2" "gitea.dev/services/context" + user_service "gitea.dev/services/user" "github.com/markbates/goth" ) @@ -50,6 +51,14 @@ func oauth2SignInSync(ctx *context.Context, authSourceID int64, u *user_model.Us } } + // sync user flags (admin/restricted) + isAdmin, isRestricted := getUserAdminAndRestrictedFromGroupClaims(oauth2Source, &gothUser) + if isAdmin.Has() || isRestricted.Has() { + if err = user_service.UpdateUser(ctx, u, &user_service.UpdateOptions{IsAdmin: isAdmin, IsRestricted: isRestricted}); err != nil { + log.Error("Unable to sync OAuth2 user admin or restricted status %s: %v", gothUser.Provider, err) + } + } + err = oauth2UpdateSSHPubIfNeed(ctx, authSource, &gothUser, u) if err != nil { log.Error("Unable to sync OAuth2 SSH public key %s: %v", gothUser.Provider, err) diff --git a/tests/integration/auth_oauth2_test.go b/tests/integration/auth_oauth2_test.go index bcd4981149..3bb3d56e7b 100644 --- a/tests/integration/auth_oauth2_test.go +++ b/tests/integration/auth_oauth2_test.go @@ -54,7 +54,7 @@ func TestMigrateAzureADV2ToOIDC(t *testing.T) { ) // The fake OIDC server issues tokens containing both sub and oid claims, mirroring what Azure AD v2.0 returns. - srv := newFakeOIDCServer(t, subValue, oidValue) + srv := newFakeOIDCServer(t, FakeOIDCConfig{Sub: subValue, OID: oidValue}) // --- Step 1: Establish the legacy Azure AD V2 state --- // Create an azureadv2 auth source. In production this would have been the source used before the migration. @@ -138,7 +138,7 @@ func TestOIDCIgnoresStaleExternalLoginLinks(t *testing.T) { setup := func(t *testing.T, sourceName, sub, userName, email string) (*auth_model.Source, *user_model.User) { t.Helper() - srv := newFakeOIDCServerWithProfile(t, sub, sub+"-oid", email, "OIDC Test User") + srv := newFakeOIDCServer(t, FakeOIDCConfig{Sub: sub, OID: sub + "-oid", Email: email, Name: "OIDC Test User"}) addOAuth2Source(t, sourceName, oauth2.Source{ Provider: "openidConnect", ClientID: "test-client-id", @@ -191,14 +191,27 @@ func TestOIDCIgnoresStaleExternalLoginLinks(t *testing.T) { }) } -// newFakeOIDCServer starts an httptest.Server that implements the minimum OIDC endpoints needed to complete a sign-in flow: -func newFakeOIDCServer(t *testing.T, sub, oid string) *httptest.Server { - return newFakeOIDCServerWithProfile(t, sub, oid, sub+"@example.com", "OIDC Test User") +// FakeOIDCConfig holds configuration for the fake OIDC server used in tests. +type FakeOIDCConfig struct { + Sub string + OID string + Email string + Name string + Groups []string } -func newFakeOIDCServerWithProfile(t *testing.T, sub, oid, email, name string) *httptest.Server { +// newFakeOIDCServer starts a httptest.Server that implements the minimum OIDC endpoints needed to complete a sign-in flow +func newFakeOIDCServer(t *testing.T, cfg FakeOIDCConfig) *httptest.Server { t.Helper() + // Set defaults for backward compatibility with existing tests + if cfg.Email == "" { + cfg.Email = cfg.Sub + "@example.com" + } + if cfg.Name == "" { + cfg.Name = "OIDC Test User" + } + var srv *httptest.Server srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -212,11 +225,18 @@ func newFakeOIDCServerWithProfile(t *testing.T, sub, oid, email, name string) *h }) case "/token": // returns an ID token with both "sub" and "oid" claims so tests can verify which one ends up as ExternalID claims := map[string]any{ - "iss": srv.URL, - "aud": "test-client-id", - "exp": time.Now().Add(time.Hour).Unix(), - "sub": sub, - "oid": oid, + "iss": srv.URL, + "aud": "test-client-id", + "exp": time.Now().Add(time.Hour).Unix(), + "sub": cfg.Sub, + "email": cfg.Email, + "name": cfg.Name, + } + if cfg.OID != "" { + claims["oid"] = cfg.OID + } + if cfg.Groups != nil { + claims["groups"] = cfg.Groups } payload, _ := json.Marshal(claims) header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none"}`)) @@ -232,11 +252,18 @@ func newFakeOIDCServerWithProfile(t *testing.T, sub, oid, email, name string) *h }) case "/userinfo": // sub MUST match the id_token sub; goth rejects mismatches. - _ = json.NewEncoder(w).Encode(map[string]any{ - "sub": sub, - "email": email, - "name": name, - }) + response := map[string]any{ + "sub": cfg.Sub, + "email": cfg.Email, + "name": cfg.Name, + } + if cfg.OID != "" { + response["oid"] = cfg.OID + } + if cfg.Groups != nil { + response["groups"] = cfg.Groups + } + _ = json.NewEncoder(w).Encode(response) default: http.NotFound(w, r) } @@ -264,3 +291,147 @@ func doOIDCSignIn(t *testing.T, sourceName string) { callbackURL := fmt.Sprintf("/user/oauth2/%s/callback?code=test-code&state=%s", sourceName, url.QueryEscape(state)) session.MakeRequest(t, NewRequest(t, "GET", callbackURL), http.StatusSeeOther) } + +// newOIDCSource is a helper function to create a configured OAuth2 source for testing +func newOIDCSource(srv *httptest.Server, withAdmin, withRestricted bool) oauth2.Source { + src := oauth2.Source{ + Provider: "openidConnect", + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + OpenIDConnectAutoDiscoveryURL: srv.URL + "/.well-known/openid-configuration", + GroupClaimName: "groups", + } + if withAdmin { + src.AdminGroup = "admins" + } + if withRestricted { + src.RestrictedGroup = "restricted-users" + } + return src +} + +// TestOAuth2GroupClaimsAppliedOnFirstLogin verifies that group claims from OAuth2/OIDC +// are correctly applied to newly created users on the first login +func TestOAuth2GroupClaimsAppliedOnFirstLogin(t *testing.T) { + defer tests.PrepareTestEnv(t)() + // Enable auto-registration to ensure first login creates user with group claims + defer test.MockVariableValue(&setting.OAuth2Client.EnableAutoRegistration, true)() + // Use sub claim as username for deterministic user naming + defer test.MockVariableValue(&setting.OAuth2Client.Username, setting.OAuth2UsernameUserid)() + + tt := []struct { + Name string + IsAdmin bool + IsRestricted bool + SourceName string + }{ + { + Name: "user in both admin and restricted groups", + IsAdmin: true, + IsRestricted: true, + SourceName: "test-group-claims", + }, + { + Name: "no groups", + IsAdmin: false, + IsRestricted: false, + SourceName: "test-no-groups", + }, + } + for _, tc := range tt { + t.Run(tc.Name, func(t *testing.T) { + // Set up OIDC server with group claims + srv := newFakeOIDCServer(t, FakeOIDCConfig{ + Sub: tc.SourceName, + Email: tc.SourceName + "@example.com", + Name: "Test User", + Groups: []string{"admins", "restricted-users"}, + }) + + // Ensure it's the first login so no user in database + unittest.AssertNotExistsBean(t, &user_model.User{Name: tc.SourceName}) + + addOAuth2Source(t, tc.SourceName, newOIDCSource(srv, tc.IsAdmin, tc.IsRestricted)) + + doOIDCSignIn(t, tc.SourceName) + + user := unittest.AssertExistsAndLoadBean(t, &user_model.User{Name: tc.SourceName}) + assert.Equal(t, tc.IsAdmin, user.IsAdmin) + assert.Equal(t, tc.IsRestricted, user.IsRestricted) + assert.Equal(t, auth_model.OAuth2, user.LoginType) + }) + } +} + +// TestOAuth2GroupClaimsManualLinking tests that group claims are applied correctly +// when a user goes through the manual linking flow (auto-registration disabled). +func TestOAuth2GroupClaimsManualLinking(t *testing.T) { + defer tests.PrepareTestEnv(t)() + // Disable auto-registration to force manual linking flow + defer test.MockVariableValue(&setting.OAuth2Client.EnableAutoRegistration, false)() + defer test.MockVariableValue(&setting.Service.AllowOnlyInternalRegistration, false)() + + tt := []struct { + Name string + IsAdmin bool + IsRestricted bool + SourceName string + }{ + { + Name: "user in both admin and restricted groups", + IsAdmin: true, + IsRestricted: true, + SourceName: "test-group-claims-manual-linking", + }, + { + Name: "no groups", + IsAdmin: false, + IsRestricted: false, + SourceName: "test-no-groups-manual-linking", + }, + } + + for _, tc := range tt { + t.Run(tc.Name, func(t *testing.T) { + srv := newFakeOIDCServer(t, FakeOIDCConfig{ + Sub: tc.SourceName, + Email: tc.SourceName + "@example.com", + Name: "Manual User", + Groups: []string{"admins", "restricted-users"}, + }) + addOAuth2Source(t, tc.SourceName, newOIDCSource(srv, tc.IsAdmin, tc.IsRestricted)) + unittest.AssertNotExistsBean(t, &user_model.User{Name: tc.SourceName}) + session := emptyTestSession(t) + resp := session.MakeRequest(t, NewRequest(t, "GET", "/user/oauth2/"+tc.SourceName), http.StatusTemporaryRedirect) + + location := resp.Header().Get("Location") + u, err := url.Parse(location) + require.NoError(t, err) + state := u.Query().Get("state") + require.NotEmpty(t, state, "redirect to OIDC provider must include state") + + callbackURL := fmt.Sprintf("/user/oauth2/%s/callback?code=test-code&state=%s", tc.SourceName, url.QueryEscape(state)) + session.MakeRequest(t, NewRequest(t, "GET", callbackURL), http.StatusSeeOther) + + // Submit the form to create a new account + linkAccountResp := session.MakeRequest(t, NewRequest(t, "GET", "/user/link_account"), http.StatusOK) + // Verify we're on the link account page + assert.Contains(t, linkAccountResp.Body.String(), "link_account") + + // Use NewRequestWithValues to POST form data (no CSRF needed in tests) + // Field names are lowercase in HTML forms: user_name, email, password, retype + req := NewRequestWithValues(t, "POST", "/user/link_account_signup", map[string]string{ + "user_name": tc.SourceName, + "email": tc.SourceName + "@example.com", + "password": "", // AllowOnlyExternalRegistration means no password needed + "retype": "", + }) + session.MakeRequest(t, req, http.StatusSeeOther) + + user := unittest.AssertExistsAndLoadBean(t, &user_model.User{Name: tc.SourceName}) + assert.Equal(t, tc.IsAdmin, user.IsAdmin) + assert.Equal(t, tc.IsRestricted, user.IsRestricted) + assert.Equal(t, auth_model.OAuth2, user.LoginType) + }) + } +}