From fe32ffe1819038f7edadd749b7c1ef84667d1c81 Mon Sep 17 00:00:00 2001
From: Lunny Xiao <xiaolunwen@gmail.com>
Date: Mon, 30 Dec 2024 10:21:57 -0800
Subject: [PATCH] Merge updatecommentattachment functions (#33044)

Extract from #32178
---
 models/issues/comment.go      | 57 ++++++++++++-----------------------
 models/issues/comment_test.go | 18 +++++++++++
 models/issues/issue_update.go | 15 ++-------
 routers/web/repo/issue.go     |  2 +-
 4 files changed, 42 insertions(+), 50 deletions(-)

diff --git a/models/issues/comment.go b/models/issues/comment.go
index e4537aa872..880a6ce8c5 100644
--- a/models/issues/comment.go
+++ b/models/issues/comment.go
@@ -592,26 +592,26 @@ func (c *Comment) LoadAttachments(ctx context.Context) error {
 	return nil
 }
 
-// UpdateAttachments update attachments by UUIDs for the comment
-func (c *Comment) UpdateAttachments(ctx context.Context, uuids []string) error {
-	ctx, committer, err := db.TxContext(ctx)
-	if err != nil {
-		return err
+// UpdateCommentAttachments update attachments by UUIDs for the comment
+func UpdateCommentAttachments(ctx context.Context, c *Comment, uuids []string) error {
+	if len(uuids) == 0 {
+		return nil
 	}
-	defer committer.Close()
-
-	attachments, err := repo_model.GetAttachmentsByUUIDs(ctx, uuids)
-	if err != nil {
-		return fmt.Errorf("getAttachmentsByUUIDs [uuids: %v]: %w", uuids, err)
-	}
-	for i := 0; i < len(attachments); i++ {
-		attachments[i].IssueID = c.IssueID
-		attachments[i].CommentID = c.ID
-		if err := repo_model.UpdateAttachment(ctx, attachments[i]); err != nil {
-			return fmt.Errorf("update attachment [id: %d]: %w", attachments[i].ID, err)
+	return db.WithTx(ctx, func(ctx context.Context) error {
+		attachments, err := repo_model.GetAttachmentsByUUIDs(ctx, uuids)
+		if err != nil {
+			return fmt.Errorf("getAttachmentsByUUIDs [uuids: %v]: %w", uuids, err)
 		}
-	}
-	return committer.Commit()
+		for i := 0; i < len(attachments); i++ {
+			attachments[i].IssueID = c.IssueID
+			attachments[i].CommentID = c.ID
+			if err := repo_model.UpdateAttachment(ctx, attachments[i]); err != nil {
+				return fmt.Errorf("update attachment [id: %d]: %w", attachments[i].ID, err)
+			}
+		}
+		c.Attachments = attachments
+		return nil
+	})
 }
 
 // LoadAssigneeUserAndTeam if comment.Type is CommentTypeAssignees, then load assignees
@@ -878,7 +878,7 @@ func updateCommentInfos(ctx context.Context, opts *CreateCommentOptions, comment
 	// Check comment type.
 	switch opts.Type {
 	case CommentTypeCode:
-		if err = updateAttachments(ctx, opts, comment); err != nil {
+		if err = UpdateCommentAttachments(ctx, comment, opts.Attachments); err != nil {
 			return err
 		}
 		if comment.ReviewID != 0 {
@@ -898,7 +898,7 @@ func updateCommentInfos(ctx context.Context, opts *CreateCommentOptions, comment
 		}
 		fallthrough
 	case CommentTypeReview:
-		if err = updateAttachments(ctx, opts, comment); err != nil {
+		if err = UpdateCommentAttachments(ctx, comment, opts.Attachments); err != nil {
 			return err
 		}
 	case CommentTypeReopen, CommentTypeClose:
@@ -910,23 +910,6 @@ func updateCommentInfos(ctx context.Context, opts *CreateCommentOptions, comment
 	return UpdateIssueCols(ctx, opts.Issue, "updated_unix")
 }
 
-func updateAttachments(ctx context.Context, opts *CreateCommentOptions, comment *Comment) error {
-	attachments, err := repo_model.GetAttachmentsByUUIDs(ctx, opts.Attachments)
-	if err != nil {
-		return fmt.Errorf("getAttachmentsByUUIDs [uuids: %v]: %w", opts.Attachments, err)
-	}
-	for i := range attachments {
-		attachments[i].IssueID = opts.Issue.ID
-		attachments[i].CommentID = comment.ID
-		// No assign value could be 0, so ignore AllCols().
-		if _, err = db.GetEngine(ctx).ID(attachments[i].ID).Update(attachments[i]); err != nil {
-			return fmt.Errorf("update attachment [%d]: %w", attachments[i].ID, err)
-		}
-	}
-	comment.Attachments = attachments
-	return nil
-}
-
 func createDeadlineComment(ctx context.Context, doer *user_model.User, issue *Issue, newDeadlineUnix timeutil.TimeStamp) (*Comment, error) {
 	var content string
 	var commentType CommentType
diff --git a/models/issues/comment_test.go b/models/issues/comment_test.go
index d81f33f953..32b86891e8 100644
--- a/models/issues/comment_test.go
+++ b/models/issues/comment_test.go
@@ -45,6 +45,24 @@ func TestCreateComment(t *testing.T) {
 	unittest.AssertInt64InRange(t, now, then, int64(updatedIssue.UpdatedUnix))
 }
 
+func Test_UpdateCommentAttachment(t *testing.T) {
+	assert.NoError(t, unittest.PrepareTestDatabase())
+
+	comment := unittest.AssertExistsAndLoadBean(t, &issues_model.Comment{ID: 1})
+	attachment := repo_model.Attachment{
+		Name: "test.txt",
+	}
+	assert.NoError(t, db.Insert(db.DefaultContext, &attachment))
+
+	err := issues_model.UpdateCommentAttachments(db.DefaultContext, comment, []string{attachment.UUID})
+	assert.NoError(t, err)
+
+	attachment2 := unittest.AssertExistsAndLoadBean(t, &repo_model.Attachment{ID: attachment.ID})
+	assert.EqualValues(t, attachment.Name, attachment2.Name)
+	assert.EqualValues(t, comment.ID, attachment2.CommentID)
+	assert.EqualValues(t, comment.IssueID, attachment2.IssueID)
+}
+
 func TestFetchCodeComments(t *testing.T) {
 	assert.NoError(t, unittest.PrepareTestDatabase())
 
diff --git a/models/issues/issue_update.go b/models/issues/issue_update.go
index 03863fe968..479834045c 100644
--- a/models/issues/issue_update.go
+++ b/models/issues/issue_update.go
@@ -405,19 +405,10 @@ func NewIssueWithIndex(ctx context.Context, doer *user_model.User, opts NewIssue
 		return err
 	}
 
-	if len(opts.Attachments) > 0 {
-		attachments, err := repo_model.GetAttachmentsByUUIDs(ctx, opts.Attachments)
-		if err != nil {
-			return fmt.Errorf("getAttachmentsByUUIDs [uuids: %v]: %w", opts.Attachments, err)
-		}
-
-		for i := 0; i < len(attachments); i++ {
-			attachments[i].IssueID = opts.Issue.ID
-			if _, err = e.ID(attachments[i].ID).Update(attachments[i]); err != nil {
-				return fmt.Errorf("update attachment [id: %d]: %w", attachments[i].ID, err)
-			}
-		}
+	if err := UpdateIssueAttachments(ctx, opts.Issue.ID, opts.Attachments); err != nil {
+		return err
 	}
+
 	if err = opts.Issue.LoadAttributes(ctx); err != nil {
 		return err
 	}
diff --git a/routers/web/repo/issue.go b/routers/web/repo/issue.go
index a3a4e73d7b..f150897a2d 100644
--- a/routers/web/repo/issue.go
+++ b/routers/web/repo/issue.go
@@ -602,7 +602,7 @@ func updateAttachments(ctx *context.Context, item any, files []string) error {
 		case *issues_model.Issue:
 			err = issues_model.UpdateIssueAttachments(ctx, content.ID, files)
 		case *issues_model.Comment:
-			err = content.UpdateAttachments(ctx, files)
+			err = issues_model.UpdateCommentAttachments(ctx, content, files)
 		default:
 			return fmt.Errorf("unknown Type: %T", content)
 		}