Browse Source

Refactor and fix incorrect comment (#1247)

Ethan Koenig 2 years ago
parent
commit
ec0ae5d50c

+ 1 - 1
cmd/serv.go

@@ -232,7 +232,7 @@ func runServ(c *cli.Context) error {
232 232
 				fail("internal error", "Failed to get user by key ID(%d): %v", keyID, err)
233 233
 			}
234 234
 
235
-			mode, err := models.AccessLevel(user, repo)
235
+			mode, err := models.AccessLevel(user.ID, repo)
236 236
 			if err != nil {
237 237
 				fail("Internal error", "Failed to check access: %v", err)
238 238
 			} else if mode < requestedMode {

+ 12 - 12
models/access.go

@@ -59,21 +59,21 @@ type Access struct {
59 59
 	Mode   AccessMode
60 60
 }
61 61
 
62
-func accessLevel(e Engine, user *User, repo *Repository) (AccessMode, error) {
62
+func accessLevel(e Engine, userID int64, repo *Repository) (AccessMode, error) {
63 63
 	mode := AccessModeNone
64 64
 	if !repo.IsPrivate {
65 65
 		mode = AccessModeRead
66 66
 	}
67 67
 
68
-	if user == nil {
68
+	if userID == 0 {
69 69
 		return mode, nil
70 70
 	}
71 71
 
72
-	if user.ID == repo.OwnerID {
72
+	if userID == repo.OwnerID {
73 73
 		return AccessModeOwner, nil
74 74
 	}
75 75
 
76
-	a := &Access{UserID: user.ID, RepoID: repo.ID}
76
+	a := &Access{UserID: userID, RepoID: repo.ID}
77 77
 	if has, err := e.Get(a); !has || err != nil {
78 78
 		return mode, err
79 79
 	}
@@ -81,19 +81,19 @@ func accessLevel(e Engine, user *User, repo *Repository) (AccessMode, error) {
81 81
 }
82 82
 
83 83
 // AccessLevel returns the Access a user has to a repository. Will return NoneAccess if the
84
-// user does not have access. User can be nil!
85
-func AccessLevel(user *User, repo *Repository) (AccessMode, error) {
86
-	return accessLevel(x, user, repo)
84
+// user does not have access.
85
+func AccessLevel(userID int64, repo *Repository) (AccessMode, error) {
86
+	return accessLevel(x, userID, repo)
87 87
 }
88 88
 
89
-func hasAccess(e Engine, user *User, repo *Repository, testMode AccessMode) (bool, error) {
90
-	mode, err := accessLevel(e, user, repo)
89
+func hasAccess(e Engine, userID int64, repo *Repository, testMode AccessMode) (bool, error) {
90
+	mode, err := accessLevel(e, userID, repo)
91 91
 	return testMode <= mode, err
92 92
 }
93 93
 
94
-// HasAccess returns true if someone has the request access level. User can be nil!
95
-func HasAccess(user *User, repo *Repository, testMode AccessMode) (bool, error) {
96
-	return hasAccess(x, user, repo, testMode)
94
+// HasAccess returns true if user has access to repo
95
+func HasAccess(userID int64, repo *Repository, testMode AccessMode) (bool, error) {
96
+	return hasAccess(x, userID, repo, testMode)
97 97
 }
98 98
 
99 99
 type repoAccess struct {

+ 8 - 8
models/access_test.go

@@ -25,19 +25,19 @@ func TestAccessLevel(t *testing.T) {
25 25
 	repo1 := AssertExistsAndLoadBean(t, &Repository{OwnerID: 2, IsPrivate: false}).(*Repository)
26 26
 	repo2 := AssertExistsAndLoadBean(t, &Repository{OwnerID: 3, IsPrivate: true}).(*Repository)
27 27
 
28
-	level, err := AccessLevel(user1, repo1)
28
+	level, err := AccessLevel(user1.ID, repo1)
29 29
 	assert.NoError(t, err)
30 30
 	assert.Equal(t, AccessModeOwner, level)
31 31
 
32
-	level, err = AccessLevel(user1, repo2)
32
+	level, err = AccessLevel(user1.ID, repo2)
33 33
 	assert.NoError(t, err)
34 34
 	assert.Equal(t, AccessModeWrite, level)
35 35
 
36
-	level, err = AccessLevel(user2, repo1)
36
+	level, err = AccessLevel(user2.ID, repo1)
37 37
 	assert.NoError(t, err)
38 38
 	assert.Equal(t, AccessModeRead, level)
39 39
 
40
-	level, err = AccessLevel(user2, repo2)
40
+	level, err = AccessLevel(user2.ID, repo2)
41 41
 	assert.NoError(t, err)
42 42
 	assert.Equal(t, AccessModeNone, level)
43 43
 }
@@ -51,19 +51,19 @@ func TestHasAccess(t *testing.T) {
51 51
 	repo2 := AssertExistsAndLoadBean(t, &Repository{OwnerID: 3, IsPrivate: true}).(*Repository)
52 52
 
53 53
 	for _, accessMode := range accessModes {
54
-		has, err := HasAccess(user1, repo1, accessMode)
54
+		has, err := HasAccess(user1.ID, repo1, accessMode)
55 55
 		assert.NoError(t, err)
56 56
 		assert.True(t, has)
57 57
 
58
-		has, err = HasAccess(user1, repo2, accessMode)
58
+		has, err = HasAccess(user1.ID, repo2, accessMode)
59 59
 		assert.NoError(t, err)
60 60
 		assert.Equal(t, accessMode <= AccessModeWrite, has)
61 61
 
62
-		has, err = HasAccess(user2, repo1, accessMode)
62
+		has, err = HasAccess(user2.ID, repo1, accessMode)
63 63
 		assert.NoError(t, err)
64 64
 		assert.Equal(t, accessMode <= AccessModeRead, has)
65 65
 
66
-		has, err = HasAccess(user2, repo2, accessMode)
66
+		has, err = HasAccess(user2.ID, repo2, accessMode)
67 67
 		assert.NoError(t, err)
68 68
 		assert.Equal(t, accessMode <= AccessModeNone, has)
69 69
 	}

+ 9 - 18
models/issue.go

@@ -374,7 +374,7 @@ func (issue *Issue) RemoveLabel(doer *User, label *Label) error {
374 374
 		return err
375 375
 	}
376 376
 
377
-	if has, err := HasAccess(doer, issue.Repo, AccessModeWrite); err != nil {
377
+	if has, err := HasAccess(doer.ID, issue.Repo, AccessModeWrite); err != nil {
378 378
 		return err
379 379
 	} else if !has {
380 380
 		return ErrLabelNotExist{}
@@ -415,7 +415,7 @@ func (issue *Issue) ClearLabels(doer *User) (err error) {
415 415
 		return err
416 416
 	}
417 417
 
418
-	if has, err := hasAccess(sess, doer, issue.Repo, AccessModeWrite); err != nil {
418
+	if has, err := hasAccess(sess, doer.ID, issue.Repo, AccessModeWrite); err != nil {
419 419
 		return err
420 420
 	} else if !has {
421 421
 		return ErrLabelNotExist{}
@@ -809,23 +809,14 @@ func newIssue(e *xorm.Session, doer *User, opts NewIssueOptions) (err error) {
809 809
 		}
810 810
 	}
811 811
 
812
-	if opts.Issue.AssigneeID > 0 {
813
-		assignee, err := getUserByID(e, opts.Issue.AssigneeID)
814
-		if err != nil && !IsErrUserNotExist(err) {
815
-			return fmt.Errorf("getUserByID: %v", err)
812
+	if assigneeID := opts.Issue.AssigneeID; assigneeID > 0 {
813
+		valid, err := hasAccess(e, assigneeID, opts.Repo, AccessModeWrite)
814
+		if err != nil {
815
+			return fmt.Errorf("hasAccess [user_id: %d, repo_id: %d]: %v", assigneeID, opts.Repo.ID, err)
816 816
 		}
817
-
818
-		// Assume assignee is invalid and drop silently.
819
-		opts.Issue.AssigneeID = 0
820
-		if assignee != nil {
821
-			valid, err := hasAccess(e, assignee, opts.Repo, AccessModeWrite)
822
-			if err != nil {
823
-				return fmt.Errorf("hasAccess [user_id: %d, repo_id: %d]: %v", assignee.ID, opts.Repo.ID, err)
824
-			}
825
-			if valid {
826
-				opts.Issue.AssigneeID = assignee.ID
827
-				opts.Issue.Assignee = assignee
828
-			}
817
+		if !valid {
818
+			opts.Issue.AssigneeID = 0
819
+			opts.Issue.Assignee = nil
829 820
 		}
830 821
 	}
831 822
 

+ 20 - 14
models/org_team.go

@@ -139,18 +139,19 @@ func (t *Team) removeRepository(e Engine, repo *Repository, recalculate bool) (e
139 139
 		}
140 140
 	}
141 141
 
142
-	if err = t.getMembers(e); err != nil {
143
-		return fmt.Errorf("get team members: %v", err)
142
+	teamUsers, err := getTeamUsersByTeamID(e, t.ID)
143
+	if err != nil {
144
+		return fmt.Errorf("getTeamUsersByTeamID: %v", err)
144 145
 	}
145
-	for _, u := range t.Members {
146
-		has, err := hasAccess(e, u, repo, AccessModeRead)
146
+	for _, teamUser:= range teamUsers {
147
+		has, err := hasAccess(e, teamUser.UID, repo, AccessModeRead)
147 148
 		if err != nil {
148 149
 			return err
149 150
 		} else if has {
150 151
 			continue
151 152
 		}
152 153
 
153
-		if err = watchRepo(e, u.ID, repo.ID, false); err != nil {
154
+		if err = watchRepo(e, teamUser.UID, repo.ID, false); err != nil {
154 155
 			return err
155 156
 		}
156 157
 	}
@@ -399,20 +400,25 @@ func IsTeamMember(orgID, teamID, userID int64) bool {
399 400
 	return isTeamMember(x, orgID, teamID, userID)
400 401
 }
401 402
 
402
-func getTeamMembers(e Engine, teamID int64) (_ []*User, err error) {
403
+func getTeamUsersByTeamID(e Engine, teamID int64) ([]*TeamUser, error) {
403 404
 	teamUsers := make([]*TeamUser, 0, 10)
404
-	if err = e.
405
+	return teamUsers, e.
405 406
 		Where("team_id=?", teamID).
406
-		Find(&teamUsers); err != nil {
407
+		Find(&teamUsers)
408
+}
409
+
410
+func getTeamMembers(e Engine, teamID int64) (_ []*User, err error) {
411
+	teamUsers, err := getTeamUsersByTeamID(e, teamID)
412
+	if err != nil {
407 413
 		return nil, fmt.Errorf("get team-users: %v", err)
408 414
 	}
409
-	members := make([]*User, 0, len(teamUsers))
410
-	for i := range teamUsers {
411
-		member := new(User)
412
-		if _, err = e.Id(teamUsers[i].UID).Get(member); err != nil {
413
-			return nil, fmt.Errorf("get user '%d': %v", teamUsers[i].UID, err)
415
+	members := make([]*User, len(teamUsers))
416
+	for i, teamUser := range teamUsers {
417
+		member, err := getUserByID(e, teamUser.UID)
418
+		if err != nil {
419
+			return nil, fmt.Errorf("get user '%d': %v", teamUser.UID, err)
414 420
 		}
415
-		members = append(members, member)
421
+		members[i] = member
416 422
 	}
417 423
 	return members, nil
418 424
 }

+ 1 - 1
models/org_team_test.go

@@ -243,7 +243,7 @@ func TestDeleteTeam(t *testing.T) {
243 243
 	// check that team members don't have "leftover" access to repos
244 244
 	user := AssertExistsAndLoadBean(t, &User{ID: 4}).(*User)
245 245
 	repo := AssertExistsAndLoadBean(t, &Repository{ID: 3}).(*Repository)
246
-	accessMode, err := AccessLevel(user, repo)
246
+	accessMode, err := AccessLevel(user.ID, repo)
247 247
 	assert.NoError(t, err)
248 248
 	assert.True(t, accessMode < AccessModeWrite)
249 249
 }

+ 1 - 1
models/release.go

@@ -365,7 +365,7 @@ func DeleteReleaseByID(id int64, u *User, delTag bool) error {
365 365
 		return fmt.Errorf("GetRepositoryByID: %v", err)
366 366
 	}
367 367
 
368
-	has, err := HasAccess(u, repo, AccessModeWrite)
368
+	has, err := HasAccess(u.ID, repo, AccessModeWrite)
369 369
 	if err != nil {
370 370
 		return fmt.Errorf("HasAccess: %v", err)
371 371
 	} else if !has {

+ 1 - 1
models/repo.go

@@ -531,7 +531,7 @@ func (repo *Repository) ComposeCompareURL(oldCommitID, newCommitID string) strin
531 531
 
532 532
 // HasAccess returns true when user has access to this repository
533 533
 func (repo *Repository) HasAccess(u *User) bool {
534
-	has, _ := HasAccess(u, repo, AccessModeRead)
534
+	has, _ := HasAccess(u.ID, repo, AccessModeRead)
535 535
 	return has
536 536
 }
537 537
 

+ 1 - 1
models/ssh_key.go

@@ -794,7 +794,7 @@ func DeleteDeployKey(doer *User, id int64) error {
794 794
 		if err != nil {
795 795
 			return fmt.Errorf("GetRepositoryByID: %v", err)
796 796
 		}
797
-		yes, err := HasAccess(doer, repo, AccessModeAdmin)
797
+		yes, err := HasAccess(doer.ID, repo, AccessModeAdmin)
798 798
 		if err != nil {
799 799
 			return fmt.Errorf("HasAccess: %v", err)
800 800
 		} else if !yes {

+ 3 - 3
models/user.go

@@ -478,7 +478,7 @@ func (u *User) DeleteAvatar() error {
478 478
 
479 479
 // IsAdminOfRepo returns true if user has admin or higher access of repository.
480 480
 func (u *User) IsAdminOfRepo(repo *Repository) bool {
481
-	has, err := HasAccess(u, repo, AccessModeAdmin)
481
+	has, err := HasAccess(u.ID, repo, AccessModeAdmin)
482 482
 	if err != nil {
483 483
 		log.Error(3, "HasAccess: %v", err)
484 484
 	}
@@ -487,7 +487,7 @@ func (u *User) IsAdminOfRepo(repo *Repository) bool {
487 487
 
488 488
 // IsWriterOfRepo returns true if user has write access to given repository.
489 489
 func (u *User) IsWriterOfRepo(repo *Repository) bool {
490
-	has, err := HasAccess(u, repo, AccessModeWrite)
490
+	has, err := HasAccess(u.ID, repo, AccessModeWrite)
491 491
 	if err != nil {
492 492
 		log.Error(3, "HasAccess: %v", err)
493 493
 	}
@@ -1103,7 +1103,7 @@ func GetUserByID(id int64) (*User, error) {
1103 1103
 
1104 1104
 // GetAssigneeByID returns the user with write access of repository by given ID.
1105 1105
 func GetAssigneeByID(repo *Repository, userID int64) (*User, error) {
1106
-	has, err := HasAccess(&User{ID: userID}, repo, AccessModeWrite)
1106
+	has, err := HasAccess(userID, repo, AccessModeWrite)
1107 1107
 	if err != nil {
1108 1108
 		return nil, err
1109 1109
 	} else if !has {

+ 1 - 1
modules/context/repo.go

@@ -219,7 +219,7 @@ func RepoAssignment(args ...bool) macaron.Handler {
219 219
 		if ctx.IsSigned && ctx.User.IsAdmin {
220 220
 			ctx.Repo.AccessMode = models.AccessModeOwner
221 221
 		} else {
222
-			mode, err := models.AccessLevel(ctx.User, repo)
222
+			mode, err := models.AccessLevel(ctx.User.ID, repo)
223 223
 			if err != nil {
224 224
 				ctx.Handle(500, "AccessLevel", err)
225 225
 				return

+ 2 - 2
modules/lfs/server.go

@@ -463,7 +463,7 @@ func authenticate(ctx *context.Context, repository *models.Repository, authoriza
463 463
 	}
464 464
 
465 465
 	if ctx.IsSigned {
466
-		accessCheck, _ := models.HasAccess(ctx.User, repository, accessMode)
466
+		accessCheck, _ := models.HasAccess(ctx.User.ID, repository, accessMode)
467 467
 		return accessCheck
468 468
 	}
469 469
 
@@ -499,7 +499,7 @@ func authenticate(ctx *context.Context, repository *models.Repository, authoriza
499 499
 		return false
500 500
 	}
501 501
 
502
-	accessCheck, _ := models.HasAccess(userModel, repository, accessMode)
502
+	accessCheck, _ := models.HasAccess(userModel.ID, repository, accessMode)
503 503
 	return accessCheck
504 504
 }
505 505
 

+ 1 - 1
routers/api/v1/api.go

@@ -70,7 +70,7 @@ func repoAssignment() macaron.Handler {
70 70
 		if ctx.IsSigned && ctx.User.IsAdmin {
71 71
 			ctx.Repo.AccessMode = models.AccessModeOwner
72 72
 		} else {
73
-			mode, err := models.AccessLevel(ctx.User, repo)
73
+			mode, err := models.AccessLevel(ctx.User.ID, repo)
74 74
 			if err != nil {
75 75
 				ctx.Error(500, "AccessLevel", err)
76 76
 				return

+ 3 - 3
routers/api/v1/org/team.go

@@ -131,7 +131,7 @@ func GetTeamRepos(ctx *context.APIContext) {
131 131
 	}
132 132
 	repos := make([]*api.Repository, len(team.Repos))
133 133
 	for i, repo := range team.Repos {
134
-		access, err := models.AccessLevel(ctx.User, repo)
134
+		access, err := models.AccessLevel(ctx.User.ID, repo)
135 135
 		if err != nil {
136 136
 			ctx.Error(500, "GetTeamRepos", err)
137 137
 			return
@@ -161,7 +161,7 @@ func AddTeamRepository(ctx *context.APIContext) {
161 161
 	if ctx.Written() {
162 162
 		return
163 163
 	}
164
-	if access, err := models.AccessLevel(ctx.User, repo); err != nil {
164
+	if access, err := models.AccessLevel(ctx.User.ID, repo); err != nil {
165 165
 		ctx.Error(500, "AccessLevel", err)
166 166
 		return
167 167
 	} else if access < models.AccessModeAdmin {
@@ -181,7 +181,7 @@ func RemoveTeamRepository(ctx *context.APIContext) {
181 181
 	if ctx.Written() {
182 182
 		return
183 183
 	}
184
-	if access, err := models.AccessLevel(ctx.User, repo); err != nil {
184
+	if access, err := models.AccessLevel(ctx.User.ID, repo); err != nil {
185 185
 		ctx.Error(500, "AccessLevel", err)
186 186
 		return
187 187
 	} else if access < models.AccessModeAdmin {

+ 1 - 1
routers/api/v1/repo/fork.go

@@ -20,7 +20,7 @@ func ListForks(ctx *context.APIContext) {
20 20
 	}
21 21
 	apiForks := make([]*api.Repository, len(forks))
22 22
 	for i, fork := range forks {
23
-		access, err := models.AccessLevel(ctx.User, fork)
23
+		access, err := models.AccessLevel(ctx.User.ID, fork)
24 24
 		if err != nil {
25 25
 			ctx.Error(500, "AccessLevel", err)
26 26
 			return

+ 1 - 1
routers/api/v1/repo/release.go

@@ -40,7 +40,7 @@ func ListReleases(ctx *context.APIContext) {
40 40
 		return
41 41
 	}
42 42
 	rels := make([]*api.Release, len(releases))
43
-	access, err := models.AccessLevel(ctx.User, ctx.Repo.Repository)
43
+	access, err := models.AccessLevel(ctx.User.ID, ctx.Repo.Repository)
44 44
 	if err != nil {
45 45
 		ctx.Error(500, "AccessLevel", err)
46 46
 		return

+ 3 - 3
routers/api/v1/repo/repo.go

@@ -64,7 +64,7 @@ func Search(ctx *context.APIContext) {
64 64
 			})
65 65
 			return
66 66
 		}
67
-		accessMode, err := models.AccessLevel(ctx.User, repo)
67
+		accessMode, err := models.AccessLevel(ctx.User.ID, repo)
68 68
 		if err != nil {
69 69
 			ctx.JSON(500, map[string]interface{}{
70 70
 				"ok":    false,
@@ -218,7 +218,7 @@ func Migrate(ctx *context.APIContext, form auth.MigrateRepoForm) {
218 218
 // see https://github.com/gogits/go-gogs-client/wiki/Repositories#get
219 219
 func Get(ctx *context.APIContext) {
220 220
 	repo := ctx.Repo.Repository
221
-	access, err := models.AccessLevel(ctx.User, repo)
221
+	access, err := models.AccessLevel(ctx.User.ID, repo)
222 222
 	if err != nil {
223 223
 		ctx.Error(500, "GetRepository", err)
224 224
 		return
@@ -238,7 +238,7 @@ func GetByID(ctx *context.APIContext) {
238 238
 		return
239 239
 	}
240 240
 
241
-	access, err := models.AccessLevel(ctx.User, repo)
241
+	access, err := models.AccessLevel(ctx.User.ID, repo)
242 242
 	if err != nil {
243 243
 		ctx.Error(500, "GetRepositoryByID", err)
244 244
 		return

+ 2 - 5
routers/api/v1/user/star.go

@@ -18,13 +18,10 @@ func getStarredRepos(userID int64, private bool) ([]*api.Repository, error) {
18 18
 	if err != nil {
19 19
 		return nil, err
20 20
 	}
21
-	user, err := models.GetUserByID(userID)
22
-	if err != nil {
23
-		return nil, err
24
-	}
21
+
25 22
 	repos := make([]*api.Repository, len(starredRepos))
26 23
 	for i, starred := range starredRepos {
27
-		access, err := models.AccessLevel(user, starred)
24
+		access, err := models.AccessLevel(userID, starred)
28 25
 		if err != nil {
29 26
 			return nil, err
30 27
 		}

+ 1 - 5
routers/api/v1/user/watch.go

@@ -31,14 +31,10 @@ func getWatchedRepos(userID int64, private bool) ([]*api.Repository, error) {
31 31
 	if err != nil {
32 32
 		return nil, err
33 33
 	}
34
-	user, err := models.GetUserByID(userID)
35
-	if err != nil {
36
-		return nil, err
37
-	}
38 34
 
39 35
 	repos := make([]*api.Repository, len(watchedRepos))
40 36
 	for i, watched := range watchedRepos {
41
-		access, err := models.AccessLevel(user, watched)
37
+		access, err := models.AccessLevel(userID, watched)
42 38
 		if err != nil {
43 39
 			return nil, err
44 40
 		}

+ 2 - 2
routers/repo/http.go

@@ -152,13 +152,13 @@ func HTTP(ctx *context.Context) {
152 152
 			}
153 153
 
154 154
 			if !isPublicPull {
155
-				has, err := models.HasAccess(authUser, repo, accessMode)
155
+				has, err := models.HasAccess(authUser.ID, repo, accessMode)
156 156
 				if err != nil {
157 157
 					ctx.Handle(http.StatusInternalServerError, "HasAccess", err)
158 158
 					return
159 159
 				} else if !has {
160 160
 					if accessMode == models.AccessModeRead {
161
-						has, err = models.HasAccess(authUser, repo, models.AccessModeWrite)
161
+						has, err = models.HasAccess(authUser.ID, repo, models.AccessModeWrite)
162 162
 						if err != nil {
163 163
 							ctx.Handle(http.StatusInternalServerError, "HasAccess2", err)
164 164
 							return