From d4e6731b1ed47358eb32ca0c062f40773d43c75b Mon Sep 17 00:00:00 2001 From: Denis Arh Date: Sat, 28 Jul 2018 10:04:07 +0200 Subject: [PATCH] Simplified and reduced repository code --- crm/repository/module.go | 12 ++++----- sam/repository/attachment.go | 31 ++++++----------------- sam/repository/channel.go | 25 ++++++++----------- sam/repository/generics.go | 15 ++++++++++-- sam/repository/message.go | 34 ++++++------------------- sam/repository/organisation.go | 34 ++++++------------------- sam/repository/reaction.go | 33 ++++++------------------- sam/repository/team.go | 33 ++++++------------------- sam/repository/user.go | 45 +++++++++------------------------- 9 files changed, 78 insertions(+), 184 deletions(-) diff --git a/crm/repository/module.go b/crm/repository/module.go index 6f55b64a1..f18d38ce4 100644 --- a/crm/repository/module.go +++ b/crm/repository/module.go @@ -28,7 +28,7 @@ func (r module) FindByID(ctx context.Context, id uint64) (*types.Module, error) func (r module) Find(ctx context.Context) ([]*types.Module, error) { db := factory.Database.MustGet() mod := make([]*types.Module, 0) - if err := db.Select(&mod, "SELECT * FROM crm_module ORDER BY name ASC"); err != nil { + if err := db.With(ctx).Select(&mod, "SELECT * FROM crm_module ORDER BY name ASC"); err != nil { println(err.Error()) return nil, ErrDatabaseError } else { @@ -39,7 +39,7 @@ func (r module) Find(ctx context.Context) ([]*types.Module, error) { func (r module) Create(ctx context.Context, mod *types.Module) (*types.Module, error) { db := factory.Database.MustGet() mod.SetID(factory.Sonyflake.NextID()) - if err := db.Insert("crm_module", mod); err != nil { + if err := db.With(ctx).Insert("crm_module", mod); err != nil { return nil, ErrDatabaseError } else { return mod, nil @@ -48,7 +48,7 @@ func (r module) Create(ctx context.Context, mod *types.Module) (*types.Module, e func (r module) Update(ctx context.Context, mod *types.Module) (*types.Module, error) { db := factory.Database.MustGet() - if err := db.Replace("crm_module", mod); err != nil { + if err := db.With(ctx).Replace("crm_module", mod); err != nil { return nil, ErrDatabaseError } else { return mod, nil @@ -69,10 +69,10 @@ func (r module) DeleteByID(ctx context.Context, id uint64) error { // m := module{}.New() // m.SetID(r.id).SetName(r.name) // if m.GetID() > 0 { -// return m, db.Replace("crm_module", m) +// return m, db.With(ctx).Replace("crm_module", m) // } // m.SetID(factory.Sonyflake.NextID()) -// return m, db.Insert("crm_module", m) +// return m, db.With(ctx).Insert("crm_module", m) //} // //func (r module) ContentList(r *moduleContentListRequest) (interface{}, error) { @@ -83,7 +83,7 @@ func (r module) DeleteByID(ctx context.Context, id uint64) error { // } // // res := make([]ModuleContentRow, 0) -// err = db.Select(&res, "select * from crm_module order by name asc") +// err = db.With(ctx).Select(&res, "select * from crm_module order by name asc") // return res, err //} // diff --git a/sam/repository/attachment.go b/sam/repository/attachment.go index 51c0342b2..508cf7d41 100644 --- a/sam/repository/attachment.go +++ b/sam/repository/attachment.go @@ -23,19 +23,15 @@ func Attachment() attachment { func (r attachment) FindByID(ctx context.Context, id uint64) (*types.Attachment, error) { db := factory.Database.MustGet() - + sql := "SELECT * FROM attachments WHERE id = ? AND " + sqlAttachmentScope mod := &types.Attachment{} - if err := db.GetContext(ctx, mod, "SELECT * FROM attachments WHERE id = ? AND "+sqlAttachmentScope, id); err != nil { - return nil, err - } else if mod.ID == 0 { - return nil, ErrAttachmentNotFound - } else { - return mod, nil - } + + return mod, isFound(db.With(ctx).Get(mod, sql, id), mod.ID > 0, ErrAttachmentNotFound) } func (r attachment) FindByRange(ctx context.Context, channelID, fromAttachmentID, toAttachmentID uint64) ([]*types.Attachment, error) { db := factory.Database.MustGet() + rval := make([]*types.Attachment, 0) sql := ` SELECT * @@ -44,12 +40,7 @@ func (r attachment) FindByRange(ctx context.Context, channelID, fromAttachmentID AND rel_channel = ? AND deleted_at IS NULL` - rval := make([]*types.Attachment, 0) - if err := db.SelectContext(ctx, &rval, sql, fromAttachmentID, toAttachmentID, channelID); err != nil { - return nil, err - } - - return rval, nil + return rval, db.With(ctx).Select(&rval, sql, fromAttachmentID, toAttachmentID, channelID) } func (r attachment) Create(ctx context.Context, mod *types.Attachment) (*types.Attachment, error) { @@ -62,11 +53,7 @@ func (r attachment) Create(ctx context.Context, mod *types.Attachment) (*types.A mod.SetAttachment([]byte("{}")) } - if err := db.Insert("attachments", mod); err != nil { - return nil, err - } else { - return mod, nil - } + return mod, db.With(ctx).Insert("attachments", mod) } func (r attachment) Update(ctx context.Context, mod *types.Attachment) (*types.Attachment, error) { @@ -75,11 +62,7 @@ func (r attachment) Update(ctx context.Context, mod *types.Attachment) (*types.A now := time.Now() mod.SetUpdatedAt(&now) - if err := db.Replace("attachments", mod); err != nil { - return nil, err - } else { - return mod, nil - } + return mod, db.With(ctx).Replace("attachments", mod) } func (r attachment) Delete(ctx context.Context, id uint64) error { diff --git a/sam/repository/channel.go b/sam/repository/channel.go index 1bd990b49..7f9293434 100644 --- a/sam/repository/channel.go +++ b/sam/repository/channel.go @@ -23,21 +23,17 @@ func Channel() channel { func (r channel) FindByID(ctx context.Context, id uint64) (*types.Channel, error) { db := factory.Database.MustGet() - mod := &types.Channel{} - if err := db.GetContext(ctx, mod, "SELECT * FROM channels WHERE id = ? AND "+sqlChannelScope, id); err != nil { - return nil, err - } else if mod.ID == 0 { - return nil, ErrChannelNotFound - } else { - return mod, nil - } + sql := "SELECT * FROM channels WHERE id = ? AND " + sqlChannelScope + + return mod, isFound(db.With(ctx).Get(mod, sql, id), mod.ID > 0, ErrChannelNotFound) } func (r channel) Find(ctx context.Context, filter *types.ChannelFilter) ([]*types.Channel, error) { db := factory.Database.MustGet() + params := make([]interface{}, 0) + rval := make([]*types.Channel, 0) - var params = make([]interface{}, 0) sql := "SELECT * FROM channels WHERE " + sqlChannelScope if filter != nil { @@ -49,8 +45,7 @@ func (r channel) Find(ctx context.Context, filter *types.ChannelFilter) ([]*type sql += " ORDER BY name ASC" - rval := make([]*types.Channel, 0) - return rval, db.SelectContext(ctx, &rval, sql, params...) + return rval, db.With(ctx).Select(&rval, sql, params...) } func (r channel) Create(ctx context.Context, mod *types.Channel) (*types.Channel, error) { @@ -63,7 +58,7 @@ func (r channel) Create(ctx context.Context, mod *types.Channel) (*types.Channel mod.SetMeta([]byte("{}")) } - return mod, db.Insert("channels", mod) + return mod, db.With(ctx).Insert("channels", mod) } func (r channel) Update(ctx context.Context, mod *types.Channel) (*types.Channel, error) { @@ -72,17 +67,17 @@ func (r channel) Update(ctx context.Context, mod *types.Channel) (*types.Channel now := time.Now() mod.SetUpdatedAt(&now) - return mod, db.Replace("channels", mod) + return mod, db.With(ctx).Replace("channels", mod) } func (r channel) AddMember(ctx context.Context, channelID, userID uint64) error { sql := `INSERT INTO channel_members (rel_channel, rel_user) VALUES (?, ?)` - return exec(factory.Database.MustGet().ExecContext(ctx, sql, channelID, userID)) + return exec(factory.Database.MustGet().With(ctx).Exec(sql, channelID, userID)) } func (r channel) RemoveMember(ctx context.Context, channelID, userID uint64) error { sql := `DELETE FROM channel_members WHERE rel_channel = ? AND rel_user = ?` - return exec(factory.Database.MustGet().ExecContext(ctx, sql, channelID, userID)) + return exec(factory.Database.MustGet().With(ctx).Exec(sql, channelID, userID)) } func (r channel) Archive(ctx context.Context, id uint64) error { diff --git a/sam/repository/generics.go b/sam/repository/generics.go index b413a8397..9b384f503 100644 --- a/sam/repository/generics.go +++ b/sam/repository/generics.go @@ -11,7 +11,7 @@ func simpleUpdate(ctx context.Context, tableName, columnName string, value inter sql := fmt.Sprintf("UPDATE %s SET %s = ? WHERE id = ?", tableName, columnName) - _, err = db.ExecContext(ctx, sql, value, id) + _, err = db.With(ctx).Exec(sql, value, id) return err } @@ -20,10 +20,21 @@ func simpleDelete(ctx context.Context, tableName string, id uint64) (err error) sql := fmt.Sprintf("DELETE FROM %s WHERE id = ?", tableName) - _, err = db.ExecContext(ctx, sql, id) + _, err = db.With(ctx).Exec(sql, id) return err } func exec(_ interface{}, err error) error { return err } + +// Returns err if set otherwise it returns nerr if not valid +func isFound(err error, valid bool, nerr error) error { + if err != nil { + return err + } else if !valid { + return nerr + } + + return nil +} diff --git a/sam/repository/message.go b/sam/repository/message.go index 352bc4b44..4c29d9d54 100644 --- a/sam/repository/message.go +++ b/sam/repository/message.go @@ -23,23 +23,17 @@ func Message() message { func (r message) FindByID(ctx context.Context, id uint64) (*types.Message, error) { db := factory.Database.MustGet() - + mod := &types.Message{} sql := "SELECT id, COALESCE(type,'') AS type, message, rel_user, rel_channel, COALESCE(reply_to, 0) AS reply_to FROM messages WHERE id = ? AND " + sqlMessageScope - mod := &types.Message{} - if err := db.GetContext(ctx, mod, sql, id); err != nil { - return nil, err - } else if mod.ID == 0 { - return nil, ErrMessageNotFound - } else { - return mod, nil - } + return mod, isFound(db.With(ctx).Get(mod, sql, id), mod.ID > 0, ErrMessageNotFound) } func (r message) Find(ctx context.Context, filter *types.MessageFilter) ([]*types.Message, error) { db := factory.Database.MustGet() + params := make([]interface{}, 0) + rval := make([]*types.Message, 0) - var params = make([]interface{}, 0) sql := "SELECT id, COALESCE(type,'') AS type, message, rel_user, rel_channel, COALESCE(reply_to, 0) AS reply_to FROM messages WHERE " + sqlMessageScope if filter != nil { @@ -71,13 +65,7 @@ func (r message) Find(ctx context.Context, filter *types.MessageFilter) ([]*type sql += " LIMIT ? " params = append(params, filter.Limit) } - - rval := make([]*types.Message, 0) - if err := db.SelectContext(ctx, &rval, sql, params...); err != nil { - return nil, err - } else { - return rval, nil - } + return rval, db.With(ctx).Select(&rval, sql, params...) } func (r message) Create(ctx context.Context, mod *types.Message) (*types.Message, error) { @@ -86,11 +74,7 @@ func (r message) Create(ctx context.Context, mod *types.Message) (*types.Message mod.SetID(factory.Sonyflake.NextID()) mod.SetCreatedAt(time.Now()) - if err := db.Insert("messages", mod); err != nil { - return nil, err - } else { - return mod, nil - } + return mod, db.With(ctx).Insert("messages", mod) } func (r message) Update(ctx context.Context, mod *types.Message) (*types.Message, error) { @@ -99,11 +83,7 @@ func (r message) Update(ctx context.Context, mod *types.Message) (*types.Message now := time.Now() mod.SetUpdatedAt(&now) - if err := db.Replace("messages", mod); err != nil { - return nil, err - } else { - return mod, nil - } + return mod, db.With(ctx).Replace("messages", mod) } func (r message) Delete(ctx context.Context, id uint64) error { diff --git a/sam/repository/organisation.go b/sam/repository/organisation.go index cb34eea63..c39ba67b5 100644 --- a/sam/repository/organisation.go +++ b/sam/repository/organisation.go @@ -23,21 +23,16 @@ func Organisation() organisation { func (r organisation) FindByID(ctx context.Context, id uint64) (*types.Organisation, error) { db := factory.Database.MustGet() - + sql := "SELECT * FROM organisations WHERE id = ? AND " + sqlOrganisationScope mod := &types.Organisation{} - if err := db.Get(mod, "SELECT * FROM organisations WHERE id = ? AND "+sqlOrganisationScope, id); err != nil { - return nil, err - } else if mod.ID == 0 { - return nil, ErrOrganisationNotFound - } else { - return mod, nil - } + + return mod, isFound(db.Get(mod, sql, id), mod.ID > 0, ErrOrganisationNotFound) } func (r organisation) Find(ctx context.Context, filter *types.OrganisationFilter) ([]*types.Organisation, error) { db := factory.Database.MustGet() - - var params = make([]interface{}, 0) + rval := make([]*types.Organisation, 0) + params := make([]interface{}, 0) sql := "SELECT * FROM organisations WHERE " + sqlOrganisationScope if filter != nil { @@ -49,12 +44,7 @@ func (r organisation) Find(ctx context.Context, filter *types.OrganisationFilter sql += " ORDER BY name ASC" - rval := make([]*types.Organisation, 0) - if err := db.Select(&rval, sql, params...); err != nil { - return nil, err - } else { - return rval, nil - } + return rval, db.With(ctx).Select(&rval, sql, params...) } func (r organisation) Create(ctx context.Context, mod *types.Organisation) (*types.Organisation, error) { @@ -63,11 +53,7 @@ func (r organisation) Create(ctx context.Context, mod *types.Organisation) (*typ mod.SetID(factory.Sonyflake.NextID()) mod.SetCreatedAt(time.Now()) - if err := db.Insert("organisations", mod); err != nil { - return nil, err - } else { - return mod, nil - } + return mod, db.With(ctx).Insert("organisations", mod) } func (r organisation) Update(ctx context.Context, mod *types.Organisation) (*types.Organisation, error) { @@ -76,11 +62,7 @@ func (r organisation) Update(ctx context.Context, mod *types.Organisation) (*typ now := time.Now() mod.SetUpdatedAt(&now) - if err := db.Replace("organisations", mod); err != nil { - return nil, err - } else { - return mod, nil - } + return mod, db.With(ctx).Replace("organisations", mod) } func (r organisation) Archive(ctx context.Context, id uint64) error { diff --git a/sam/repository/reaction.go b/sam/repository/reaction.go index 26559c965..ac5a0f64c 100644 --- a/sam/repository/reaction.go +++ b/sam/repository/reaction.go @@ -21,32 +21,23 @@ func Reaction() reaction { func (r reaction) FindByID(ctx context.Context, id uint64) (*types.Reaction, error) { db := factory.Database.MustGet() - + sql := "SELECT * FROM reactions WHERE id = ?" mod := &types.Reaction{} - if err := db.GetContext(ctx, mod, "SELECT * FROM reactions WHERE id = ?", id); err != nil { - return nil, err - } else if mod.ID == 0 { - return nil, ErrReactionNotFound - } else { - return mod, nil - } + + return mod, isFound(db.With(ctx).Get(mod, sql, id), mod.ID > 0, ErrReactionNotFound) + } func (r reaction) FindByRange(ctx context.Context, channelID, fromReactionID, toReactionID uint64) ([]*types.Reaction, error) { db := factory.Database.MustGet() - + rval := make([]*types.Reaction, 0) sql := ` SELECT * FROM reactions WHERE rel_reaction BETWEEN ? AND ? AND rel_channel = ?` - rval := make([]*types.Reaction, 0) - if err := db.Select(&rval, sql, fromReactionID, toReactionID, channelID); err != nil { - return nil, err - } - - return rval, nil + return rval, db.With(ctx).Select(&rval, sql, fromReactionID, toReactionID, channelID) } func (r reaction) Create(ctx context.Context, mod *types.Reaction) (*types.Reaction, error) { @@ -55,19 +46,11 @@ func (r reaction) Create(ctx context.Context, mod *types.Reaction) (*types.React mod.SetID(factory.Sonyflake.NextID()) mod.SetCreatedAt(time.Now()) - if err := db.Insert("reactions", mod); err != nil { - return nil, err - } else { - return mod, nil - } + return mod, db.With(ctx).Insert("reactions", mod) } func (r reaction) Delete(ctx context.Context, id uint64) error { db := factory.Database.MustGet() - if _, err := db.ExecContext(ctx, "DELETE FROM reactions WHERE id = ?", id); err != nil { - return err - } else { - return nil - } + return exec(db.With(ctx).Exec("DELETE FROM reactions WHERE id = ?", id)) } diff --git a/sam/repository/team.go b/sam/repository/team.go index eae0312d0..ec2654f1f 100644 --- a/sam/repository/team.go +++ b/sam/repository/team.go @@ -23,21 +23,17 @@ func Team() team { func (r team) FindByID(ctx context.Context, id uint64) (*types.Team, error) { db := factory.Database.MustGet() - + sql := "SELECT * FROM teams WHERE id = ? AND " + sqlTeamScope mod := &types.Team{} - if err := db.Get(mod, "SELECT * FROM teams WHERE id = ? AND "+sqlTeamScope, id); err != nil { - return nil, err - } else if mod.ID == 0 { - return nil, ErrTeamNotFound - } else { - return mod, nil - } + + return mod, isFound(db.Get(mod, sql, id), mod.ID > 0, ErrTeamNotFound) } func (r team) Find(ctx context.Context, filter *types.TeamFilter) ([]*types.Team, error) { db := factory.Database.MustGet() + rval := make([]*types.Team, 0) + params := make([]interface{}, 0) - var params = make([]interface{}, 0) sql := "SELECT * FROM teams WHERE " + sqlTeamScope if filter != nil { @@ -49,12 +45,7 @@ func (r team) Find(ctx context.Context, filter *types.TeamFilter) ([]*types.Team sql += " ORDER BY name ASC" - rval := make([]*types.Team, 0) - if err := db.Select(&rval, sql, params...); err != nil { - return nil, err - } else { - return rval, nil - } + return rval, db.With(ctx).Select(&rval, sql, params...) } func (r team) Create(ctx context.Context, mod *types.Team) (*types.Team, error) { @@ -63,11 +54,7 @@ func (r team) Create(ctx context.Context, mod *types.Team) (*types.Team, error) mod.SetID(factory.Sonyflake.NextID()) mod.SetCreatedAt(time.Now()) - if err := db.Insert("teams", mod); err != nil { - return nil, err - } else { - return mod, nil - } + return mod, db.With(ctx).Insert("teams", mod) } func (r team) Update(ctx context.Context, mod *types.Team) (*types.Team, error) { @@ -76,11 +63,7 @@ func (r team) Update(ctx context.Context, mod *types.Team) (*types.Team, error) now := time.Now() mod.SetUpdatedAt(&now) - if err := db.Replace("teams", mod); err != nil { - return nil, err - } else { - return mod, nil - } + return mod, db.With(ctx).Replace("teams", mod) } func (r team) Archive(ctx context.Context, id uint64) error { diff --git a/sam/repository/user.go b/sam/repository/user.go index 1e898aa90..5c5b85b85 100644 --- a/sam/repository/user.go +++ b/sam/repository/user.go @@ -24,34 +24,24 @@ func User() user { func (r user) FindByUsername(ctx context.Context, username string) (*types.User, error) { db := factory.Database.MustGet() - + sql := "SELECT * FROM users WHERE username = ? AND " + sqlUserScope mod := &types.User{} - if err := db.Get(mod, "SELECT * FROM users WHERE username = ? AND "+sqlUserScope, username); err != nil { - return nil, err - } else if mod.ID == 0 { - return nil, ErrUserNotFound - } else { - return mod, nil - } + + return mod, isFound(db.Get(mod, sql, username), mod.ID > 0, ErrUserNotFound) } func (r user) FindByID(ctx context.Context, id uint64) (*types.User, error) { db := factory.Database.MustGet() - + sql := "SELECT * FROM users WHERE id = ? AND " + sqlUserScope mod := &types.User{} - if err := db.Get(mod, "SELECT * FROM users WHERE id = ? AND "+sqlUserScope, id); err != nil { - return nil, err - } else if mod.ID == 0 { - return nil, ErrUserNotFound - } else { - return mod, nil - } + + return mod, isFound(db.Get(mod, sql, id), mod.ID > 0, ErrUserNotFound) } func (r user) Find(ctx context.Context, filter *types.UserFilter) ([]*types.User, error) { db := factory.Database.MustGet() - - var params = make([]interface{}, 0) + rval := make([]*types.User, 0) + params := make([]interface{}, 0) sql := "SELECT * FROM users WHERE " + sqlUserScope if filter != nil { @@ -68,12 +58,7 @@ func (r user) Find(ctx context.Context, filter *types.UserFilter) ([]*types.User sql += " ORDER BY username ASC" - rval := make([]*types.User, 0) - if err := db.Select(&rval, sql, params...); err != nil { - return nil, err - } else { - return rval, nil - } + return rval, db.With(ctx).Select(&rval, sql, params...) } func (r user) Create(ctx context.Context, mod *types.User) (*types.User, error) { @@ -86,11 +71,7 @@ func (r user) Create(ctx context.Context, mod *types.User) (*types.User, error) mod.SetMeta([]byte("{}")) } - if err := db.Insert("users", mod); err != nil { - return nil, err - } else { - return mod, nil - } + return mod, db.With(ctx).Insert("users", mod) } func (r user) Update(ctx context.Context, mod *types.User) (*types.User, error) { @@ -99,11 +80,7 @@ func (r user) Update(ctx context.Context, mod *types.User) (*types.User, error) now := time.Now() mod.SetUpdatedAt(&now) - if err := db.Replace("users", mod); err != nil { - return nil, err - } else { - return mod, nil - } + return mod, db.With(ctx).Replace("users", mod) } func (r user) Suspend(ctx context.Context, id uint64) error {