diff --git a/services/issue/label.go b/services/issue/label.go index 3a054d0b07..77bef8db1a 100644 --- a/services/issue/label.go +++ b/services/issue/label.go @@ -69,3 +69,29 @@ func ReplaceLabels(ctx context.Context, issue *issues_model.Issue, doer *user_mo notify_service.IssueChangeLabels(ctx, doer, issue, labels, old) return nil } + +func AddRemoveLabels(ctx context.Context, issue *issues_model.Issue, doer *user_model.User, toAddLabels, toRemoveLabels []*issues_model.Label) error { + if len(toAddLabels) == 0 && len(toRemoveLabels) == 0 { + return nil + } + + if err := db.WithTx(ctx, func(ctx context.Context) error { + if len(toAddLabels) > 0 { + if err := issues_model.NewIssueLabels(ctx, issue, toAddLabels, doer); err != nil { + return err + } + } + + for _, label := range toRemoveLabels { + if err := issues_model.DeleteIssueLabel(ctx, issue, label, doer); err != nil { + return err + } + } + return nil + }); err != nil { + return err + } + + notify_service.IssueChangeLabels(ctx, doer, issue, toAddLabels, toRemoveLabels) + return nil +} diff --git a/services/issue/label_test.go b/services/issue/label_test.go index 093b16b53f..96ee9ba45f 100644 --- a/services/issue/label_test.go +++ b/services/issue/label_test.go @@ -59,3 +59,38 @@ func TestIssue_AddLabel(t *testing.T) { unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: test.issueID, LabelID: test.labelID}) } } + +func TestIssue_AddRemoveLabels(t *testing.T) { + tests := []struct { + issueID int64 + toAddIDs []int64 + toRemoveIDs []int64 + doerID int64 + }{ + {1, []int64{2}, []int64{1}, 2}, // now there are both 1 and 2 + {1, []int64{}, []int64{1, 2}, 2}, // no label left + {1, []int64{1, 2}, []int64{}, 2}, // add them back + {1, []int64{}, []int64{}, 2}, // no-op + } + + for _, test := range tests { + assert.NoError(t, unittest.PrepareTestDatabase()) + issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: test.issueID}) + toAddLabels := make([]*issues_model.Label, len(test.toAddIDs)) + for i, labelID := range test.toAddIDs { + toAddLabels[i] = unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: labelID}) + } + toRemoveLabels := make([]*issues_model.Label, len(test.toRemoveIDs)) + for i, labelID := range test.toRemoveIDs { + toRemoveLabels[i] = unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: labelID}) + } + doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: test.doerID}) + assert.NoError(t, AddRemoveLabels(t.Context(), issue, doer, toAddLabels, toRemoveLabels)) + for _, labelID := range test.toAddIDs { + unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: test.issueID, LabelID: labelID}) + } + for _, labelID := range test.toRemoveIDs { + unittest.AssertNotExistsBean(t, &issues_model.IssueLabel{IssueID: test.issueID, LabelID: labelID}) + } + } +} diff --git a/services/projects/workflow_notifier.go b/services/projects/workflow_notifier.go index 1ab446b965..19c7309e68 100644 --- a/services/projects/workflow_notifier.go +++ b/services/projects/workflow_notifier.go @@ -345,6 +345,8 @@ func matchWorkflowsFilters(workflow *project_model.Workflow, issue *issues_model } func executeWorkflowActions(ctx context.Context, workflow *project_model.Workflow, issue *issues_model.Issue) { + var toAddedLables, toRemovedLables []*issues_model.Label + for _, action := range workflow.WorkflowActions { switch action.Type { case project_model.WorkflowActionTypeColumn: @@ -373,10 +375,7 @@ func executeWorkflowActions(ctx context.Context, workflow *project_model.Workflo log.Error("GetLabelByID: %v", err) continue } - if err := issue_service.AddLabel(ctx, issue, user_model.NewProjectWorkflowsUser(), label); err != nil { - log.Error("AddLabels: %v", err) - continue - } + toAddedLables = append(toAddedLables, label) case project_model.WorkflowActionTypeRemoveLabels: labelID, _ := strconv.ParseInt(action.Value, 10, 64) if labelID == 0 { @@ -388,12 +387,7 @@ func executeWorkflowActions(ctx context.Context, workflow *project_model.Workflo log.Error("GetLabelByID: %v", err) continue } - if err := issue_service.RemoveLabel(ctx, issue, user_model.NewProjectWorkflowsUser(), label); err != nil { - if !issues_model.IsErrRepoLabelNotExist(err) { - log.Error("RemoveLabels: %v", err) - } - continue - } + toRemovedLables = append(toRemovedLables, label) case project_model.WorkflowActionTypeIssueState: if strings.EqualFold(action.Value, "reopen") { if issue.IsClosed { @@ -414,4 +408,10 @@ func executeWorkflowActions(ctx context.Context, workflow *project_model.Workflo log.Error("Unsupported action type: %s", action.Type) } } + + if len(toAddedLables)+len(toRemovedLables) > 0 { + if err := issue_service.AddRemoveLabels(ctx, issue, user_model.NewProjectWorkflowsUser(), toAddedLables, toRemovedLables); err != nil { + log.Error("AddRemoveLabels: %v", err) + } + } }