Skip to content

Commit

Permalink
Add context parameter to some database functions (go-gitea#26055)
Browse files Browse the repository at this point in the history
To avoid deadlock problem, almost database related functions should be
have ctx as the first parameter.
This PR do a refactor for some of these functions.
  • Loading branch information
lunny authored Jul 22, 2023
1 parent c42b718 commit b167f35
Show file tree
Hide file tree
Showing 50 changed files with 209 additions and 237 deletions.
8 changes: 4 additions & 4 deletions models/activities/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,10 +391,10 @@ func (a *Action) GetIssueInfos() []string {
}

// GetIssueTitle returns the title of first issue associated
// with the action.
// with the action. This function will be invoked in template so keep db.DefaultContext here
func (a *Action) GetIssueTitle() string {
index, _ := strconv.ParseInt(a.GetIssueInfos()[0], 10, 64)
issue, err := issues_model.GetIssueByIndex(a.RepoID, index)
issue, err := issues_model.GetIssueByIndex(db.DefaultContext, a.RepoID, index)
if err != nil {
log.Error("GetIssueByIndex: %v", err)
return "500 when get issue"
Expand All @@ -404,9 +404,9 @@ func (a *Action) GetIssueTitle() string {

// GetIssueContent returns the content of first issue associated with
// this action.
func (a *Action) GetIssueContent() string {
func (a *Action) GetIssueContent(ctx context.Context) string {
index, _ := strconv.ParseInt(a.GetIssueInfos()[0], 10, 64)
issue, err := issues_model.GetIssueByIndex(a.RepoID, index)
issue, err := issues_model.GetIssueByIndex(ctx, a.RepoID, index)
if err != nil {
log.Error("GetIssueByIndex: %v", err)
return "500 when get issue"
Expand Down
54 changes: 27 additions & 27 deletions models/activities/repo_activity.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,21 @@ type ActivityStats struct {
func GetActivityStats(ctx context.Context, repo *repo_model.Repository, timeFrom time.Time, releases, issues, prs, code bool) (*ActivityStats, error) {
stats := &ActivityStats{Code: &git.CodeActivityStats{}}
if releases {
if err := stats.FillReleases(repo.ID, timeFrom); err != nil {
if err := stats.FillReleases(ctx, repo.ID, timeFrom); err != nil {
return nil, fmt.Errorf("FillReleases: %w", err)
}
}
if prs {
if err := stats.FillPullRequests(repo.ID, timeFrom); err != nil {
if err := stats.FillPullRequests(ctx, repo.ID, timeFrom); err != nil {
return nil, fmt.Errorf("FillPullRequests: %w", err)
}
}
if issues {
if err := stats.FillIssues(repo.ID, timeFrom); err != nil {
if err := stats.FillIssues(ctx, repo.ID, timeFrom); err != nil {
return nil, fmt.Errorf("FillIssues: %w", err)
}
}
if err := stats.FillUnresolvedIssues(repo.ID, timeFrom, issues, prs); err != nil {
if err := stats.FillUnresolvedIssues(ctx, repo.ID, timeFrom, issues, prs); err != nil {
return nil, fmt.Errorf("FillUnresolvedIssues: %w", err)
}
if code {
Expand Down Expand Up @@ -205,41 +205,41 @@ func (stats *ActivityStats) PublishedReleaseCount() int {
}

// FillPullRequests returns pull request information for activity page
func (stats *ActivityStats) FillPullRequests(repoID int64, fromTime time.Time) error {
func (stats *ActivityStats) FillPullRequests(ctx context.Context, repoID int64, fromTime time.Time) error {
var err error
var count int64

// Merged pull requests
sess := pullRequestsForActivityStatement(repoID, fromTime, true)
sess := pullRequestsForActivityStatement(ctx, repoID, fromTime, true)
sess.OrderBy("pull_request.merged_unix DESC")
stats.MergedPRs = make(issues_model.PullRequestList, 0)
if err = sess.Find(&stats.MergedPRs); err != nil {
return err
}
if err = stats.MergedPRs.LoadAttributes(); err != nil {
if err = stats.MergedPRs.LoadAttributes(ctx); err != nil {
return err
}

// Merged pull request authors
sess = pullRequestsForActivityStatement(repoID, fromTime, true)
sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, true)
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("pull_request").Get(&count); err != nil {
return err
}
stats.MergedPRAuthorCount = count

// Opened pull requests
sess = pullRequestsForActivityStatement(repoID, fromTime, false)
sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, false)
sess.OrderBy("issue.created_unix ASC")
stats.OpenedPRs = make(issues_model.PullRequestList, 0)
if err = sess.Find(&stats.OpenedPRs); err != nil {
return err
}
if err = stats.OpenedPRs.LoadAttributes(); err != nil {
if err = stats.OpenedPRs.LoadAttributes(ctx); err != nil {
return err
}

// Opened pull request authors
sess = pullRequestsForActivityStatement(repoID, fromTime, false)
sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, false)
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("pull_request").Get(&count); err != nil {
return err
}
Expand All @@ -248,8 +248,8 @@ func (stats *ActivityStats) FillPullRequests(repoID int64, fromTime time.Time) e
return nil
}

func pullRequestsForActivityStatement(repoID int64, fromTime time.Time, merged bool) *xorm.Session {
sess := db.GetEngine(db.DefaultContext).Where("pull_request.base_repo_id=?", repoID).
func pullRequestsForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time, merged bool) *xorm.Session {
sess := db.GetEngine(ctx).Where("pull_request.base_repo_id=?", repoID).
Join("INNER", "issue", "pull_request.issue_id = issue.id")

if merged {
Expand All @@ -264,35 +264,35 @@ func pullRequestsForActivityStatement(repoID int64, fromTime time.Time, merged b
}

// FillIssues returns issue information for activity page
func (stats *ActivityStats) FillIssues(repoID int64, fromTime time.Time) error {
func (stats *ActivityStats) FillIssues(ctx context.Context, repoID int64, fromTime time.Time) error {
var err error
var count int64

// Closed issues
sess := issuesForActivityStatement(repoID, fromTime, true, false)
sess := issuesForActivityStatement(ctx, repoID, fromTime, true, false)
sess.OrderBy("issue.closed_unix DESC")
stats.ClosedIssues = make(issues_model.IssueList, 0)
if err = sess.Find(&stats.ClosedIssues); err != nil {
return err
}

// Closed issue authors
sess = issuesForActivityStatement(repoID, fromTime, true, false)
sess = issuesForActivityStatement(ctx, repoID, fromTime, true, false)
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("issue").Get(&count); err != nil {
return err
}
stats.ClosedIssueAuthorCount = count

// New issues
sess = issuesForActivityStatement(repoID, fromTime, false, false)
sess = issuesForActivityStatement(ctx, repoID, fromTime, false, false)
sess.OrderBy("issue.created_unix ASC")
stats.OpenedIssues = make(issues_model.IssueList, 0)
if err = sess.Find(&stats.OpenedIssues); err != nil {
return err
}

// Opened issue authors
sess = issuesForActivityStatement(repoID, fromTime, false, false)
sess = issuesForActivityStatement(ctx, repoID, fromTime, false, false)
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("issue").Get(&count); err != nil {
return err
}
Expand All @@ -302,12 +302,12 @@ func (stats *ActivityStats) FillIssues(repoID int64, fromTime time.Time) error {
}

// FillUnresolvedIssues returns unresolved issue and pull request information for activity page
func (stats *ActivityStats) FillUnresolvedIssues(repoID int64, fromTime time.Time, issues, prs bool) error {
func (stats *ActivityStats) FillUnresolvedIssues(ctx context.Context, repoID int64, fromTime time.Time, issues, prs bool) error {
// Check if we need to select anything
if !issues && !prs {
return nil
}
sess := issuesForActivityStatement(repoID, fromTime, false, true)
sess := issuesForActivityStatement(ctx, repoID, fromTime, false, true)
if !issues || !prs {
sess.And("issue.is_pull = ?", prs)
}
Expand All @@ -316,8 +316,8 @@ func (stats *ActivityStats) FillUnresolvedIssues(repoID int64, fromTime time.Tim
return sess.Find(&stats.UnresolvedIssues)
}

func issuesForActivityStatement(repoID int64, fromTime time.Time, closed, unresolved bool) *xorm.Session {
sess := db.GetEngine(db.DefaultContext).Where("issue.repo_id = ?", repoID).
func issuesForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time, closed, unresolved bool) *xorm.Session {
sess := db.GetEngine(ctx).Where("issue.repo_id = ?", repoID).
And("issue.is_closed = ?", closed)

if !unresolved {
Expand All @@ -336,20 +336,20 @@ func issuesForActivityStatement(repoID int64, fromTime time.Time, closed, unreso
}

// FillReleases returns release information for activity page
func (stats *ActivityStats) FillReleases(repoID int64, fromTime time.Time) error {
func (stats *ActivityStats) FillReleases(ctx context.Context, repoID int64, fromTime time.Time) error {
var err error
var count int64

// Published releases list
sess := releasesForActivityStatement(repoID, fromTime)
sess := releasesForActivityStatement(ctx, repoID, fromTime)
sess.OrderBy("release.created_unix DESC")
stats.PublishedReleases = make([]*repo_model.Release, 0)
if err = sess.Find(&stats.PublishedReleases); err != nil {
return err
}

// Published releases authors
sess = releasesForActivityStatement(repoID, fromTime)
sess = releasesForActivityStatement(ctx, repoID, fromTime)
if _, err = sess.Select("count(distinct release.publisher_id) as `count`").Table("release").Get(&count); err != nil {
return err
}
Expand All @@ -358,8 +358,8 @@ func (stats *ActivityStats) FillReleases(repoID int64, fromTime time.Time) error
return nil
}

func releasesForActivityStatement(repoID int64, fromTime time.Time) *xorm.Session {
return db.GetEngine(db.DefaultContext).Where("release.repo_id = ?", repoID).
func releasesForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time) *xorm.Session {
return db.GetEngine(ctx).Where("release.repo_id = ?", repoID).
And("release.is_draft = ?", false).
And("release.created_unix >= ?", fromTime.Unix())
}
11 changes: 3 additions & 8 deletions models/issues/comment_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,9 @@ func (comments CommentList) loadReviews(ctx context.Context) error {
return nil
}

// loadAttributes loads all attributes
func (comments CommentList) loadAttributes(ctx context.Context) (err error) {
// LoadAttributes loads attributes of the comments, except for attachments and
// comments
func (comments CommentList) LoadAttributes(ctx context.Context) (err error) {
if err = comments.LoadPosters(ctx); err != nil {
return err
}
Expand Down Expand Up @@ -501,9 +502,3 @@ func (comments CommentList) loadAttributes(ctx context.Context) (err error) {

return comments.loadDependentIssues(ctx)
}

// LoadAttributes loads attributes of the comments, except for attachments and
// comments
func (comments CommentList) LoadAttributes() error {
return comments.loadAttributes(db.DefaultContext)
}
14 changes: 7 additions & 7 deletions models/issues/issue.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ func (issue *Issue) LoadAttributes(ctx context.Context) (err error) {
return err
}

if err = issue.Comments.loadAttributes(ctx); err != nil {
if err = issue.Comments.LoadAttributes(ctx); err != nil {
return err
}
if issue.IsTimetrackerEnabled(ctx) {
Expand Down Expand Up @@ -502,15 +502,15 @@ func (issue *Issue) GetLastEventLabelFake() string {
}

// GetIssueByIndex returns raw issue without loading attributes by index in a repository.
func GetIssueByIndex(repoID, index int64) (*Issue, error) {
func GetIssueByIndex(ctx context.Context, repoID, index int64) (*Issue, error) {
if index < 1 {
return nil, ErrIssueNotExist{}
}
issue := &Issue{
RepoID: repoID,
Index: index,
}
has, err := db.GetEngine(db.DefaultContext).Get(issue)
has, err := db.GetEngine(ctx).Get(issue)
if err != nil {
return nil, err
} else if !has {
Expand All @@ -520,12 +520,12 @@ func GetIssueByIndex(repoID, index int64) (*Issue, error) {
}

// GetIssueWithAttrsByIndex returns issue by index in a repository.
func GetIssueWithAttrsByIndex(repoID, index int64) (*Issue, error) {
issue, err := GetIssueByIndex(repoID, index)
func GetIssueWithAttrsByIndex(ctx context.Context, repoID, index int64) (*Issue, error) {
issue, err := GetIssueByIndex(ctx, repoID, index)
if err != nil {
return nil, err
}
return issue, issue.LoadAttributes(db.DefaultContext)
return issue, issue.LoadAttributes(ctx)
}

// GetIssueByID returns an issue by given ID.
Expand Down Expand Up @@ -846,7 +846,7 @@ func GetPinnedIssues(ctx context.Context, repoID int64, isPull bool) ([]*Issue,
return nil, err
}

err = IssueList(issues).LoadAttributes()
err = IssueList(issues).LoadAttributes(ctx)
if err != nil {
return nil, err
}
Expand Down
8 changes: 1 addition & 7 deletions models/issues/issue_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ func (issues IssueList) loadTotalTrackedTimes(ctx context.Context) (err error) {
}

// loadAttributes loads all attributes, expect for attachments and comments
func (issues IssueList) loadAttributes(ctx context.Context) error {
func (issues IssueList) LoadAttributes(ctx context.Context) error {
if _, err := issues.LoadRepositories(ctx); err != nil {
return fmt.Errorf("issue.loadAttributes: LoadRepositories: %w", err)
}
Expand Down Expand Up @@ -562,12 +562,6 @@ func (issues IssueList) loadAttributes(ctx context.Context) error {
return nil
}

// LoadAttributes loads attributes of the issues, except for attachments and
// comments
func (issues IssueList) LoadAttributes() error {
return issues.loadAttributes(db.DefaultContext)
}

// LoadComments loads comments
func (issues IssueList) LoadComments(ctx context.Context) error {
return issues.loadComments(ctx, builder.NewCond())
Expand Down
2 changes: 1 addition & 1 deletion models/issues/issue_list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestIssueList_LoadAttributes(t *testing.T) {
unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 4}),
}

assert.NoError(t, issueList.LoadAttributes())
assert.NoError(t, issueList.LoadAttributes(db.DefaultContext))
for _, issue := range issueList {
assert.EqualValues(t, issue.RepoID, issue.Repo.ID)
for _, label := range issue.Labels {
Expand Down
2 changes: 1 addition & 1 deletion models/issues/issue_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ func Issues(ctx context.Context, opts *IssuesOptions) ([]*Issue, error) {
return nil, fmt.Errorf("unable to query Issues: %w", err)
}

if err := issues.LoadAttributes(); err != nil {
if err := issues.LoadAttributes(ctx); err != nil {
return nil, fmt.Errorf("unable to LoadAttributes for Issues: %w", err)
}

Expand Down
23 changes: 9 additions & 14 deletions models/issues/pull_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,16 @@ func listPullRequestStatement(baseRepoID int64, opts *PullRequestsOptions) (*xor
}

// GetUnmergedPullRequestsByHeadInfo returns all pull requests that are open and has not been merged
func GetUnmergedPullRequestsByHeadInfo(repoID int64, branch string) ([]*PullRequest, error) {
func GetUnmergedPullRequestsByHeadInfo(ctx context.Context, repoID int64, branch string) ([]*PullRequest, error) {
prs := make([]*PullRequest, 0, 2)
sess := db.GetEngine(db.DefaultContext).
sess := db.GetEngine(ctx).
Join("INNER", "issue", "issue.id = pull_request.issue_id").
Where("head_repo_id = ? AND head_branch = ? AND has_merged = ? AND issue.is_closed = ? AND flow = ?", repoID, branch, false, false, PullRequestFlowGithub)
return prs, sess.Find(&prs)
}

// CanMaintainerWriteToBranch check whether user is a maintainer and could write to the branch
func CanMaintainerWriteToBranch(p access_model.Permission, branch string, user *user_model.User) bool {
func CanMaintainerWriteToBranch(ctx context.Context, p access_model.Permission, branch string, user *user_model.User) bool {
if p.CanWrite(unit.TypeCode) {
return true
}
Expand All @@ -69,18 +69,18 @@ func CanMaintainerWriteToBranch(p access_model.Permission, branch string, user *
return false
}

prs, err := GetUnmergedPullRequestsByHeadInfo(p.Units[0].RepoID, branch)
prs, err := GetUnmergedPullRequestsByHeadInfo(ctx, p.Units[0].RepoID, branch)
if err != nil {
return false
}

for _, pr := range prs {
if pr.AllowMaintainerEdit {
err = pr.LoadBaseRepo(db.DefaultContext)
err = pr.LoadBaseRepo(ctx)
if err != nil {
continue
}
prPerm, err := access_model.GetUserRepoPermission(db.DefaultContext, pr.BaseRepo, user)
prPerm, err := access_model.GetUserRepoPermission(ctx, pr.BaseRepo, user)
if err != nil {
continue
}
Expand All @@ -104,9 +104,9 @@ func HasUnmergedPullRequestsByHeadInfo(ctx context.Context, repoID int64, branch

// GetUnmergedPullRequestsByBaseInfo returns all pull requests that are open and has not been merged
// by given base information (repo and branch).
func GetUnmergedPullRequestsByBaseInfo(repoID int64, branch string) ([]*PullRequest, error) {
func GetUnmergedPullRequestsByBaseInfo(ctx context.Context, repoID int64, branch string) ([]*PullRequest, error) {
prs := make([]*PullRequest, 0, 2)
return prs, db.GetEngine(db.DefaultContext).
return prs, db.GetEngine(ctx).
Where("base_repo_id=? AND base_branch=? AND has_merged=? AND issue.is_closed=?",
repoID, branch, false, false).
OrderBy("issue.updated_unix DESC").
Expand Down Expand Up @@ -154,7 +154,7 @@ func PullRequests(baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest,
// PullRequestList defines a list of pull requests
type PullRequestList []*PullRequest

func (prs PullRequestList) loadAttributes(ctx context.Context) error {
func (prs PullRequestList) LoadAttributes(ctx context.Context) error {
if len(prs) == 0 {
return nil
}
Expand Down Expand Up @@ -199,8 +199,3 @@ func (prs PullRequestList) GetIssueIDs() []int64 {
}
return issueIDs
}

// LoadAttributes load all the prs attributes
func (prs PullRequestList) LoadAttributes() error {
return prs.loadAttributes(db.DefaultContext)
}
Loading

0 comments on commit b167f35

Please sign in to comment.