From c9daa375c239ca33474a084dfdebe12183a8fcd4 Mon Sep 17 00:00:00 2001 From: Denis Arh Date: Thu, 17 Oct 2019 16:32:53 +0200 Subject: [PATCH] Refactor user repo --- system/repository/user.go | 97 +++++++++++++++++++-------------------- system/rest/user.go | 2 + system/service/user.go | 6 +-- tests/system/user_test.go | 2 +- 4 files changed, 54 insertions(+), 53 deletions(-) diff --git a/system/repository/user.go b/system/repository/user.go index ec2e1ed49..15e9f8876 100644 --- a/system/repository/user.go +++ b/system/repository/user.go @@ -75,9 +75,8 @@ func (r user) query() squirrel.SelectBuilder { func (r user) queryNoFilter() squirrel.SelectBuilder { return squirrel. - Select(). - From(r.table() + " AS u"). - Columns(r.columns()...) + Select(r.columns()...). + From(r.table() + " AS u") } func (r *user) With(ctx context.Context, db *factory.DB) UserRepository { @@ -86,34 +85,49 @@ func (r *user) With(ctx context.Context, db *factory.DB) UserRepository { } } -func (r user) findBy(field string, value interface{}) (*types.User, error) { - var ( - query = r.query().Where("u."+field+" = ?", value) - u = &types.User{} - ) - - return u, isFound(r.fetchOne(u, query), u.ID > 0, ErrUserNotFound) -} - func (r user) FindByUsername(username string) (*types.User, error) { - return r.findBy("username", username) + return r.findOneBy("username", username) } func (r user) FindByHandle(handle string) (*types.User, error) { - return r.findBy("handle", handle) + return r.findOneBy("handle", handle) } func (r user) FindByEmail(email string) (*types.User, error) { - return r.findBy("email", email) + return r.findOneBy("email", email) } func (r user) FindByID(id uint64) (*types.User, error) { - return r.findBy("id", id) + return r.findOneBy("id", id) +} + +func (r user) findOneBy(field string, value interface{}) (*types.User, error) { + var ( + u = &types.User{} + + q = r.query(). + Where(squirrel.Eq{field: value}) + + err = rh.FetchOne(r.db(), q, u) + ) + + if err != nil { + return nil, err + } else if u.ID == 0 { + return nil, ErrUserNotFound + } + + return u, nil } func (r user) Find(filter types.UserFilter) (set types.UserSet, f types.UserFilter, err error) { f = filter - q := r.queryNoFilter() + + if f.Sort == "" { + f.Sort = "id" + } + + query := r.queryNoFilter() // Returns user filter (flt) wrapped in IF() function with cnd as condition (when cnd != nil) whereMasked := func(cnd *permissions.ResourceFilter, flt squirrel.Sqlizer) squirrel.Sqlizer { @@ -125,15 +139,15 @@ func (r user) Find(filter types.UserFilter) (set types.UserSet, f types.UserFilt } if !f.IncDeleted { - q = q.Where(squirrel.Eq{"u.deleted_at": nil}) + query = query.Where(squirrel.Eq{"u.deleted_at": nil}) } if !f.IncSuspended { - q = q.Where(squirrel.Eq{"u.suspended_at": nil}) + query = query.Where(squirrel.Eq{"u.suspended_at": nil}) } if len(f.UserID) > 0 { - q = q.Where(squirrel.Eq{"u.ID": f.UserID}) + query = query.Where(squirrel.Eq{"u.ID": f.UserID}) } if len(f.RoleID) > 0 { @@ -144,12 +158,12 @@ func (r user) Find(filter types.UserFilter) (set types.UserSet, f types.UserFilt or = append(or, squirrel.Expr("u.ID IN (SELECT rel_user FROM sys_role_member WHERE rel_role IN (?))", roleID)) } - q = q.Where(or) + query = query.Where(or) } if f.Query != "" { qs := f.Query + "%" - q = q.Where(squirrel.Or{ + query = query.Where(squirrel.Or{ squirrel.Like{"u.username": qs}, squirrel.Like{"u.handle": qs}, whereMasked(f.IsEmailUnmaskable, squirrel.Like{"u.email": qs}), @@ -158,52 +172,37 @@ func (r user) Find(filter types.UserFilter) (set types.UserSet, f types.UserFilt } if f.Email != "" { - q = q.Where(whereMasked(f.IsNameUnmaskable, squirrel.Eq{"u.name": f.Email})) + query = query.Where(whereMasked(f.IsNameUnmaskable, squirrel.Eq{"u.name": f.Email})) } if f.Username != "" { - q = q.Where(squirrel.Eq{"u.username": f.Username}) + query = query.Where(squirrel.Eq{"u.username": f.Username}) } if f.Handle != "" { - q = q.Where(squirrel.Eq{"u.handle": f.Handle}) + query = query.Where(squirrel.Eq{"u.handle": f.Handle}) } if f.Kind != "" { - q = q.Where(squirrel.Eq{"u.kind": f.Kind}) + query = query.Where(squirrel.Eq{"u.kind": f.Kind}) } if f.IsReadable != nil { - q = q.Where(f.IsReadable) + query = query.Where(f.IsReadable) } - // @todo add support for more sophisticated sorting through ql - // refactor github.com/cortezaproject/corteza-server/compose/repository/ql - // for common use (out of compose pkg) - switch f.Sort { - case "createdAt": - q = q.OrderBy("created_at") - case "updatedAt": - q = q.OrderBy("updated_at") - case "deletedAt": - q = q.OrderBy("deleted_at") - case "suspendedAt": - q = q.OrderBy("suspended_at") - case "email", "username": - q = q.OrderBy(f.Sort) - case "userID": - q = q.OrderBy("id") - default: - q = q.OrderBy("id") + var orderBy []string + if orderBy, err = rh.ParseOrder(f.Sort, r.columns()...); err != nil { + return + } else { + query = query.OrderBy(orderBy...) } - db := r.db() - - if f.Count, err = rh.Count(db, q); err != nil || f.Count == 0 { + if f.Count, err = rh.Count(r.db(), query); err != nil || f.Count == 0 { return } - return set, f, rh.FetchPaged(db, q, f.Page, f.PerPage, &set) + return set, f, rh.FetchPaged(r.db(), query, f.Page, f.PerPage, &set) } func (r user) Total() (count uint) { diff --git a/system/rest/user.go b/system/rest/user.go index 1f7e8df3e..bc8a79bc6 100644 --- a/system/rest/user.go +++ b/system/rest/user.go @@ -46,6 +46,8 @@ func (ctrl User) List(ctx context.Context, r *request.UserList) (interface{}, er IncSuspended: r.IncSuspended, IncDeleted: r.IncDeleted, + Sort: r.Sort, + PageFilter: rh.Paging(r.Page, r.PerPage), } diff --git a/system/service/user.go b/system/service/user.go index 1040fa23d..1f842b536 100644 --- a/system/service/user.go +++ b/system/service/user.go @@ -265,19 +265,19 @@ func (svc user) Update(mod *types.User) (u *types.User, err error) { func (svc user) UniqueCheck(u *types.User) (err error) { if u.Email != "" { - if ex, _ := svc.user.FindByEmail(u.Email); ex.ID > 0 && ex.ID != u.ID { + if ex, _ := svc.user.FindByEmail(u.Email); ex != nil && ex.ID > 0 && ex.ID != u.ID { return ErrUserEmailNotUnique } } if u.Username != "" { - if ex, _ := svc.user.FindByUsername(u.Username); ex.ID > 0 && ex.ID != u.ID { + if ex, _ := svc.user.FindByUsername(u.Username); ex != nil && ex.ID > 0 && ex.ID != u.ID { return ErrUserUsernameNotUnique } } if u.Handle != "" { - if ex, _ := svc.user.FindByHandle(u.Handle); ex.ID > 0 && ex.ID != u.ID { + if ex, _ := svc.user.FindByHandle(u.Handle); ex != nil && ex.ID > 0 && ex.ID != u.ID { return ErrUserHandleNotUnique } } diff --git a/tests/system/user_test.go b/tests/system/user_test.go index 8de439df9..f84e4763c 100644 --- a/tests/system/user_test.go +++ b/tests/system/user_test.go @@ -182,7 +182,7 @@ func TestUserUpdate(t *testing.T) { u := h.repoMakeUser(h.randEmail()) h.allow(types.UserPermissionResource.AppendWildcard(), "update") - newEmail := "updated-" + u.Email + newEmail := h.randEmail() h.apiInit(). Put(fmt.Sprintf("/users/%d", u.ID)).