diff --git a/sam/repository/attachment.go b/sam/repository/attachment.go index 7c8741b68..0493f9a0d 100644 --- a/sam/repository/attachment.go +++ b/sam/repository/attachment.go @@ -1,35 +1,38 @@ package repository import ( - "context" "github.com/crusttech/crust/sam/types" "github.com/titpetric/factory" "time" ) +type ( + Attachment interface { + FindAttachmentByID(id uint64) (*types.Attachment, error) + FindAttachmentByRange(channelID, fromAttachmentID, toAttachmentID uint64) ([]*types.Attachment, error) + CreateAttachment(mod *types.Attachment) (*types.Attachment, error) + UpdateAttachment(mod *types.Attachment) (*types.Attachment, error) + DeleteAttachmentByID(id uint64) error + } +) + const ( sqlAttachmentScope = "deleted_at IS NULL" ErrAttachmentNotFound = repositoryError("AttachmentNotFound") ) -type ( - attachment struct{} -) +var _ Attachment = &repository{} -func Attachment() attachment { - return attachment{} -} - -func (r attachment) FindByID(ctx context.Context, id uint64) (*types.Attachment, error) { +func (r *repository) FindAttachmentByID(id uint64) (*types.Attachment, error) { db := factory.Database.MustGet() sql := "SELECT * FROM attachments WHERE id = ? AND " + sqlAttachmentScope mod := &types.Attachment{} - return mod, isFound(db.With(ctx).Get(mod, sql, id), mod.ID > 0, ErrAttachmentNotFound) + return mod, isFound(db.With(r.ctx).Get(mod, sql, id), mod.ID > 0, ErrAttachmentNotFound) } -func (r attachment) FindByRange(ctx context.Context, channelID, fromAttachmentID, toAttachmentID uint64) ([]*types.Attachment, error) { +func (r *repository) FindAttachmentByRange(channelID, fromAttachmentID, toAttachmentID uint64) ([]*types.Attachment, error) { db := factory.Database.MustGet() rval := make([]*types.Attachment, 0) @@ -40,25 +43,25 @@ func (r attachment) FindByRange(ctx context.Context, channelID, fromAttachmentID AND rel_channel = ? AND deleted_at IS NULL` - return rval, db.With(ctx).Select(&rval, sql, fromAttachmentID, toAttachmentID, channelID) + return rval, db.With(r.ctx).Select(&rval, sql, fromAttachmentID, toAttachmentID, channelID) } -func (r attachment) Create(ctx context.Context, mod *types.Attachment) (*types.Attachment, error) { +func (r *repository) CreateAttachment(mod *types.Attachment) (*types.Attachment, error) { mod.ID = factory.Sonyflake.NextID() mod.CreatedAt = time.Now() mod.Attachment = coalesceJson(mod.Attachment, []byte("{}")) - return mod, factory.Database.MustGet().With(ctx).Insert("attachments", mod) + return mod, factory.Database.MustGet().With(r.ctx).Insert("attachments", mod) } -func (r attachment) Update(ctx context.Context, mod *types.Attachment) (*types.Attachment, error) { +func (r *repository) UpdateAttachment(mod *types.Attachment) (*types.Attachment, error) { mod.UpdatedAt = timeNowPtr() mod.Attachment = coalesceJson(mod.Attachment, []byte("{}")) - return mod, factory.Database.MustGet().With(ctx).Replace("attachments", mod) + return mod, factory.Database.MustGet().With(r.ctx).Replace("attachments", mod) } -func (r attachment) Delete(ctx context.Context, id uint64) error { - return simpleDelete(ctx, "attachments", id) +func (r *repository) DeleteAttachmentByID(id uint64) error { + return simpleDelete(r.ctx, "attachments", id) } diff --git a/sam/repository/attachment_test.go b/sam/repository/attachment_test.go index ad9e1e433..4fbc75201 100644 --- a/sam/repository/attachment_test.go +++ b/sam/repository/attachment_test.go @@ -22,7 +22,7 @@ func TestAttachment(t *testing.T) { att.ChannelID = 1 - att, err = rpo.Create(ctx, att) + att, err = rpo.CreateAttachment(ctx, att) must(t, err) if att.ChannelID != 1 { t.Fatal("Changes were not stored") @@ -30,23 +30,23 @@ func TestAttachment(t *testing.T) { att.ChannelID = 2 - att, err = rpo.Update(ctx, att) + att, err = rpo.UpdateAttachment(ctx, att) must(t, err) if att.ChannelID != 2 { t.Fatal("Changes were not stored") } - att, err = rpo.FindByID(ctx, att.ID) + att, err = rpo.FindAttachmentByID(ctx, att.ID) must(t, err) if att.ChannelID != 2 { t.Fatal("Changes were not stored") } - aa, err = rpo.FindByRange(ctx, 2, 0, att.ID) + aa, err = rpo.FindAttachmentByRange(ctx, 2, 0, att.ID) must(t, err) if len(aa) == 0 { t.Fatal("No results found") } - must(t, rpo.Delete(ctx, att.ID)) + must(t, rpo.DeleteAttachmentByID(ctx, att.ID)) } diff --git a/sam/repository/channel.go b/sam/repository/channel.go index 9950ebc32..f39820dac 100644 --- a/sam/repository/channel.go +++ b/sam/repository/channel.go @@ -1,35 +1,40 @@ package repository import ( - "context" "github.com/crusttech/crust/sam/types" "github.com/titpetric/factory" "time" ) +type ( + Channel interface { + FindChannelByID(id uint64) (*types.Channel, error) + FindChannels(filter *types.ChannelFilter) ([]*types.Channel, error) + CreateChannel(mod *types.Channel) (*types.Channel, error) + UpdateChannel(mod *types.Channel) (*types.Channel, error) + AddChannelMember(channelID, userID uint64) error + RemoveChannelMember(channelID, userID uint64) error + ArchiveChannelByID(id uint64) error + UnarchiveChannelByID(id uint64) error + DeleteChannelByID(id uint64) error + } +) + const ( sqlChannelScope = "deleted_at IS NULL AND archived_at IS NULL" ErrChannelNotFound = repositoryError("ChannelNotFound") ) -type ( - channel struct{} -) - -func Channel() channel { - return channel{} -} - -func (r channel) FindByID(ctx context.Context, id uint64) (*types.Channel, error) { +func (r *repository) FindChannelByID(id uint64) (*types.Channel, error) { db := factory.Database.MustGet() mod := &types.Channel{} sql := "SELECT * FROM channels WHERE id = ? AND " + sqlChannelScope - return mod, isFound(db.With(ctx).Get(mod, sql, id), mod.ID > 0, ErrChannelNotFound) + return mod, isFound(db.With(r.ctx).Get(mod, sql, id), mod.ID > 0, ErrChannelNotFound) } -func (r channel) Find(ctx context.Context, filter *types.ChannelFilter) ([]*types.Channel, error) { +func (r *repository) FindChannels(filter *types.ChannelFilter) ([]*types.Channel, error) { db := factory.Database.MustGet() params := make([]interface{}, 0) rval := make([]*types.Channel, 0) @@ -45,42 +50,42 @@ func (r channel) Find(ctx context.Context, filter *types.ChannelFilter) ([]*type sql += " ORDER BY name ASC" - return rval, db.With(ctx).Select(&rval, sql, params...) + return rval, db.With(r.ctx).Select(&rval, sql, params...) } -func (r channel) Create(ctx context.Context, mod *types.Channel) (*types.Channel, error) { +func (r *repository) CreateChannel(mod *types.Channel) (*types.Channel, error) { mod.ID = factory.Sonyflake.NextID() mod.CreatedAt = time.Now() mod.Meta = coalesceJson(mod.Meta, []byte("{}")) - return mod, factory.Database.MustGet().With(ctx).Insert("channels", mod) + return mod, factory.Database.MustGet().With(r.ctx).Insert("channels", mod) } -func (r channel) Update(ctx context.Context, mod *types.Channel) (*types.Channel, error) { +func (r *repository) UpdateChannel(mod *types.Channel) (*types.Channel, error) { mod.UpdatedAt = timeNowPtr() mod.Meta = coalesceJson(mod.Meta, []byte("{}")) - return mod, factory.Database.MustGet().With(ctx).Replace("channels", mod) + return mod, factory.Database.MustGet().With(r.ctx).Replace("channels", mod) } -func (r channel) AddMember(ctx context.Context, channelID, userID uint64) error { +func (r *repository) AddChannelMember(channelID, userID uint64) error { sql := `INSERT INTO channel_members (rel_channel, rel_user) VALUES (?, ?)` - return exec(factory.Database.MustGet().With(ctx).Exec(sql, channelID, userID)) + return exec(factory.Database.MustGet().With(r.ctx).Exec(sql, channelID, userID)) } -func (r channel) RemoveMember(ctx context.Context, channelID, userID uint64) error { +func (r *repository) RemoveChannelMember(channelID, userID uint64) error { sql := `DELETE FROM channel_members WHERE rel_channel = ? AND rel_user = ?` - return exec(factory.Database.MustGet().With(ctx).Exec(sql, channelID, userID)) + return exec(factory.Database.MustGet().With(r.ctx).Exec(sql, channelID, userID)) } -func (r channel) Archive(ctx context.Context, id uint64) error { - return simpleUpdate(ctx, "channels", "archived_at", time.Now(), id) +func (r *repository) ArchiveChannelByID(id uint64) error { + return simpleUpdate(r.ctx, "channels", "archived_at", time.Now(), id) } -func (r channel) Unarchive(ctx context.Context, id uint64) error { - return simpleUpdate(ctx, "channels", "archived_at", nil, id) +func (r *repository) UnarchiveChannelByID(id uint64) error { + return simpleUpdate(r.ctx, "channels", "archived_at", nil, id) } -func (r channel) Delete(ctx context.Context, id uint64) error { - return simpleDelete(ctx, "channels", id) +func (r *repository) DeleteChannelByID(id uint64) error { + return simpleDelete(r.ctx, "channels", id) } diff --git a/sam/repository/channel_test.go b/sam/repository/channel_test.go index 223c75dd1..9686a9fc3 100644 --- a/sam/repository/channel_test.go +++ b/sam/repository/channel_test.go @@ -24,7 +24,7 @@ func TestChannel(t *testing.T) { chn.Name = name1 - chn, err = rpo.Create(ctx, chn) + chn, err = rpo.CreateChannel(ctx, chn) must(t, err) if chn.Name != name1 { t.Fatal("Changes were not stored") @@ -32,27 +32,27 @@ func TestChannel(t *testing.T) { chn.Name = name2 - chn, err = rpo.Update(ctx, chn) + chn, err = rpo.UpdateChannel(ctx, chn) must(t, err) if chn.Name != name2 { t.Fatal("Changes were not stored") } - chn, err = rpo.FindByID(ctx, chn.ID) + chn, err = rpo.FindChannelByID(ctx, chn.ID) must(t, err) if chn.Name != name2 { t.Fatal("Changes were not stored") } - cc, err = rpo.Find(ctx, &types.ChannelFilter{Query: name2}) + cc, err = rpo.FindChannels(ctx, &types.ChannelFilter{Query: name2}) must(t, err) if len(cc) == 0 { t.Fatal("No results found") } - must(t, rpo.Archive(ctx, chn.ID)) - must(t, rpo.Unarchive(ctx, chn.ID)) - must(t, rpo.Delete(ctx, chn.ID)) + must(t, rpo.ArchiveChannel(ctx, chn.ID)) + must(t, rpo.UnarchiveChannel(ctx, chn.ID)) + must(t, rpo.DeleteChannelByID(ctx, chn.ID)) } func TestChannelMembers(t *testing.T) { @@ -67,13 +67,13 @@ func TestChannelMembers(t *testing.T) { ctx := context.Background() chn := &types.Channel{} - chn, err = rpo.Create(ctx, chn) + chn, err = rpo.CreateChannel(ctx, chn) must(t, err) usr := &types.User{} usr, err = User().Create(ctx, usr) must(t, err) - must(t, rpo.AddMember(ctx, chn.ID, usr.ID)) - must(t, rpo.RemoveMember(ctx, chn.ID, usr.ID)) + must(t, rpo.AddChannelMember(ctx, chn.ID, usr.ID)) + must(t, rpo.RemoveChannelMember(ctx, chn.ID, usr.ID)) } diff --git a/sam/repository/message.go b/sam/repository/message.go index b426a5eb8..dd1b2de61 100644 --- a/sam/repository/message.go +++ b/sam/repository/message.go @@ -1,35 +1,36 @@ package repository import ( - "context" "github.com/crusttech/crust/sam/types" "github.com/titpetric/factory" "time" ) +type ( + Message interface { + FindMessageByID(id uint64) (*types.Message, error) + FindMessages(filter *types.MessageFilter) ([]*types.Message, error) + CreateMessage(mod *types.Message) (*types.Message, error) + UpdateMessage(mod *types.Message) (*types.Message, error) + DeleteMessageByID(id uint64) error + } +) + const ( sqlMessageScope = "deleted_at IS NULL" ErrMessageNotFound = repositoryError("MessageNotFound") ) -type ( - message struct{} -) - -func Message() message { - return message{} -} - -func (r message) FindByID(ctx context.Context, id uint64) (*types.Message, error) { +func (r *repository) FindMessageByID(id uint64) (*types.Message, error) { db := factory.Database.MustGet() mod := &types.Message{} sql := "SELECT id, COALESCE(type,'') AS type, message, rel_user, rel_channel, COALESCE(reply_to, 0) AS reply_to FROM messages WHERE id = ? AND " + sqlMessageScope - return mod, isFound(db.With(ctx).Get(mod, sql, id), mod.ID > 0, ErrMessageNotFound) + return mod, isFound(db.With(r.ctx).Get(mod, sql, id), mod.ID > 0, ErrMessageNotFound) } -func (r message) Find(ctx context.Context, filter *types.MessageFilter) ([]*types.Message, error) { +func (r *repository) FindMessages(filter *types.MessageFilter) ([]*types.Message, error) { db := factory.Database.MustGet() params := make([]interface{}, 0) rval := make([]*types.Message, 0) @@ -65,22 +66,22 @@ func (r message) Find(ctx context.Context, filter *types.MessageFilter) ([]*type sql += " LIMIT ? " params = append(params, filter.Limit) } - return rval, db.With(ctx).Select(&rval, sql, params...) + return rval, db.With(r.ctx).Select(&rval, sql, params...) } -func (r message) Create(ctx context.Context, mod *types.Message) (*types.Message, error) { +func (r *repository) CreateMessage(mod *types.Message) (*types.Message, error) { mod.ID = factory.Sonyflake.NextID() mod.CreatedAt = time.Now() - return mod, factory.Database.MustGet().With(ctx).Insert("messages", mod) + return mod, factory.Database.MustGet().With(r.ctx).Insert("messages", mod) } -func (r message) Update(ctx context.Context, mod *types.Message) (*types.Message, error) { +func (r *repository) UpdateMessage(mod *types.Message) (*types.Message, error) { mod.UpdatedAt = timeNowPtr() - return mod, factory.Database.MustGet().With(ctx).Replace("messages", mod) + return mod, factory.Database.MustGet().With(r.ctx).Replace("messages", mod) } -func (r message) Delete(ctx context.Context, id uint64) error { - return simpleDelete(ctx, "messages", id) +func (r *repository) DeleteMessageByID(id uint64) error { + return simpleDelete(r.ctx, "messages", id) } diff --git a/sam/repository/message_test.go b/sam/repository/message_test.go index 26b01e76a..63e3d143b 100644 --- a/sam/repository/message_test.go +++ b/sam/repository/message_test.go @@ -24,7 +24,7 @@ func TestMessage(t *testing.T) { msg.Message = msg1 - msg, err = rpo.Create(ctx, msg) + msg, err = rpo.CreateMessage(ctx, msg) must(t, err) if msg.Message != msg1 { t.Fatal("Changes were not stored") @@ -38,13 +38,13 @@ func TestMessage(t *testing.T) { t.Fatal("Changes were not stored") } - msg, err = rpo.FindByID(ctx, msg.ID) + msg, err = rpo.FindMessageByID(ctx, msg.ID) must(t, err) if msg.Message != msg2 { t.Fatal("Changes were not stored") } - mm, err = rpo.Find(ctx, &types.MessageFilter{Query: msg2}) + mm, err = rpo.FindMessages(ctx, &types.MessageFilter{Query: msg2}) must(t, err) if len(mm) == 0 { t.Fatal("No results found") diff --git a/sam/repository/organisation.go b/sam/repository/organisation.go index 8f0bd9699..c4241ca15 100644 --- a/sam/repository/organisation.go +++ b/sam/repository/organisation.go @@ -1,27 +1,30 @@ package repository import ( - "context" "github.com/crusttech/crust/sam/types" "github.com/titpetric/factory" "time" ) +type ( + Organisation interface { + FindOrganisationByID(id uint64) (*types.Organisation, error) + FindOrganisations(filter *types.OrganisationFilter) ([]*types.Organisation, error) + CreateOrganisation(mod *types.Organisation) (*types.Organisation, error) + UpdateOrganisation(mod *types.Organisation) (*types.Organisation, error) + ArchiveOrganisationByID(id uint64) error + UnarchiveOrganisationByID(id uint64) error + DeleteOrganisationByID(id uint64) error + } +) + const ( sqlOrganisationScope = "deleted_at IS NULL AND archived_at IS NULL" ErrOrganisationNotFound = repositoryError("OrganisationNotFound") ) -type ( - organisation struct{} -) - -func Organisation() organisation { - return organisation{} -} - -func (r organisation) FindByID(ctx context.Context, id uint64) (*types.Organisation, error) { +func (r *repository) FindOrganisationByID(id uint64) (*types.Organisation, error) { db := factory.Database.MustGet() sql := "SELECT * FROM organisations WHERE id = ? AND " + sqlOrganisationScope mod := &types.Organisation{} @@ -29,7 +32,7 @@ func (r organisation) FindByID(ctx context.Context, id uint64) (*types.Organisat return mod, isFound(db.Get(mod, sql, id), mod.ID > 0, ErrOrganisationNotFound) } -func (r organisation) Find(ctx context.Context, filter *types.OrganisationFilter) ([]*types.Organisation, error) { +func (r *repository) FindOrganisations(filter *types.OrganisationFilter) ([]*types.Organisation, error) { db := factory.Database.MustGet() rval := make([]*types.Organisation, 0) params := make([]interface{}, 0) @@ -44,30 +47,30 @@ func (r organisation) Find(ctx context.Context, filter *types.OrganisationFilter sql += " ORDER BY name ASC" - return rval, db.With(ctx).Select(&rval, sql, params...) + return rval, db.With(r.ctx).Select(&rval, sql, params...) } -func (r organisation) Create(ctx context.Context, mod *types.Organisation) (*types.Organisation, error) { +func (r *repository) CreateOrganisation(mod *types.Organisation) (*types.Organisation, error) { mod.ID = factory.Sonyflake.NextID() mod.CreatedAt = time.Now() - return mod, factory.Database.MustGet().With(ctx).Insert("organisations", mod) + return mod, factory.Database.MustGet().With(r.ctx).Insert("organisations", mod) } -func (r organisation) Update(ctx context.Context, mod *types.Organisation) (*types.Organisation, error) { +func (r *repository) UpdateOrganisation(mod *types.Organisation) (*types.Organisation, error) { mod.UpdatedAt = timeNowPtr() - return mod, factory.Database.MustGet().With(ctx).Replace("organisations", mod) + return mod, factory.Database.MustGet().With(r.ctx).Replace("organisations", mod) } -func (r organisation) Archive(ctx context.Context, id uint64) error { - return simpleUpdate(ctx, "organisations", "archived_at", time.Now(), id) +func (r *repository) ArchiveOrganisationByID(id uint64) error { + return simpleUpdate(r.ctx, "organisations", "archived_at", time.Now(), id) } -func (r organisation) Unarchive(ctx context.Context, id uint64) error { - return simpleUpdate(ctx, "organisations", "archived_at", nil, id) +func (r *repository) UnarchiveOrganisationByID(id uint64) error { + return simpleUpdate(r.ctx, "organisations", "archived_at", nil, id) } -func (r organisation) Delete(ctx context.Context, id uint64) error { - return simpleDelete(ctx, "organisations", id) +func (r *repository) DeleteOrganisationByID(id uint64) error { + return simpleDelete(r.ctx, "organisations", id) } diff --git a/sam/repository/reaction.go b/sam/repository/reaction.go index 435c81a47..0b0d3e526 100644 --- a/sam/repository/reaction.go +++ b/sam/repository/reaction.go @@ -1,34 +1,34 @@ package repository import ( - "context" "github.com/crusttech/crust/sam/types" "github.com/titpetric/factory" "time" ) +type ( + Reaction interface { + FindReactionByID(id uint64) (*types.Reaction, error) + FindReactionsByRange(channelID, fromReactionID, toReactionID uint64) ([]*types.Reaction, error) + CreateReaction(mod *types.Reaction) (*types.Reaction, error) + DeleteReactionByID(id uint64) error + } +) + const ( ErrReactionNotFound = repositoryError("ReactionNotFound") ) -type ( - reaction struct{} -) - -func Reaction() reaction { - return reaction{} -} - -func (r reaction) FindByID(ctx context.Context, id uint64) (*types.Reaction, error) { +func (r *repository) FindReactionByID(id uint64) (*types.Reaction, error) { db := factory.Database.MustGet() sql := "SELECT * FROM reactions WHERE id = ?" mod := &types.Reaction{} - return mod, isFound(db.With(ctx).Get(mod, sql, id), mod.ID > 0, ErrReactionNotFound) + return mod, isFound(db.With(r.ctx).Get(mod, sql, id), mod.ID > 0, ErrReactionNotFound) } -func (r reaction) FindByRange(ctx context.Context, channelID, fromReactionID, toReactionID uint64) ([]*types.Reaction, error) { +func (r *repository) FindReactionsByRange(channelID, fromReactionID, toReactionID uint64) ([]*types.Reaction, error) { db := factory.Database.MustGet() rval := make([]*types.Reaction, 0) sql := ` @@ -37,18 +37,17 @@ func (r reaction) FindByRange(ctx context.Context, channelID, fromReactionID, to WHERE rel_reaction BETWEEN ? AND ? AND rel_channel = ?` - return rval, db.With(ctx).Select(&rval, sql, fromReactionID, toReactionID, channelID) + return rval, db.With(r.ctx).Select(&rval, sql, fromReactionID, toReactionID, channelID) } -func (r reaction) Create(ctx context.Context, mod *types.Reaction) (*types.Reaction, error) { +func (r *repository) CreateReaction(mod *types.Reaction) (*types.Reaction, error) { mod.ID = factory.Sonyflake.NextID() mod.CreatedAt = time.Now() - return mod, factory.Database.MustGet().With(ctx).Insert("reactions", mod) + return mod, factory.Database.MustGet().With(r.ctx).Insert("reactions", mod) } -func (r reaction) Delete(ctx context.Context, id uint64) error { +func (r *repository) DeleteReactionByID(id uint64) error { db := factory.Database.MustGet() - - return exec(db.With(ctx).Exec("DELETE FROM reactions WHERE id = ?", id)) + return exec(db.With(r.ctx).Exec("DELETE FROM reactions WHERE id = ?", id)) } diff --git a/sam/repository/repository.go b/sam/repository/repository.go new file mode 100644 index 000000000..b546d3521 --- /dev/null +++ b/sam/repository/repository.go @@ -0,0 +1,91 @@ +package repository + +import ( + "context" + "github.com/titpetric/factory" +) + +type ( + repository struct { + ctx context.Context + + // Current transaction + tx *factory.DB + } + + Transactionable interface { + BeginWith(ctx context.Context, callback BeginCallback) error + Begin() error + Rollback() error + Commit() error + } + + Contextable interface { + WithCtx(ctx context.Context) Interfaces + } + + Interfaces interface { + Transactionable + Contextable + + Attachment + Channel + Message + Organisation + Reaction + Team + User + } + + BeginCallback func(r Interfaces) error +) + +func New() *repository { + return &repository{ctx: context.Background()} +} + +func (r *repository) WithCtx(ctx context.Context) Interfaces { + return &repository{ctx: ctx, tx: r.tx} +} + +func (r *repository) BeginWith(ctx context.Context, callback BeginCallback) error { + tx := r.tx + if tx == nil { + tx = factory.Database.MustGet().With(ctx) + } + + txr := &repository{ctx: ctx, tx: tx} + + if err := txr.Begin(); err != nil { + return err + } + + if err := callback(txr); err != nil { + if err := txr.Rollback(); err != nil { + return err + } + + return err + } + + return txr.Commit() +} + +func (r *repository) Begin() error { + // @todo implementation + return r.tx.Begin() +} + +func (r *repository) Commit() error { + // @todo implementation + return r.tx.Commit() +} + +func (r *repository) Rollback() error { + // @todo implementation + return r.tx.Rollback() +} + +func (r *repository) db() *factory.DB { + return r.tx.With(r.ctx) +} diff --git a/sam/repository/team.go b/sam/repository/team.go index e85c2b9eb..382352469 100644 --- a/sam/repository/team.go +++ b/sam/repository/team.go @@ -1,27 +1,32 @@ package repository import ( - "context" "github.com/crusttech/crust/sam/types" "github.com/titpetric/factory" "time" ) +type ( + Team interface { + FindTeamByID(id uint64) (*types.Team, error) + FindTeams(filter *types.TeamFilter) ([]*types.Team, error) + CreateTeam(mod *types.Team) (*types.Team, error) + UpdateTeam(mod *types.Team) (*types.Team, error) + ArchiveTeamByID(id uint64) error + UnarchiveTeamByID(id uint64) error + DeleteTeamByID(id uint64) error + MergeTeamByID(id, targetTeamID uint64) error + MoveTeamByID(id, targetOrganisationID uint64) error + } +) + const ( sqlTeamScope = "deleted_at IS NULL AND archived_at IS NULL" ErrTeamNotFound = repositoryError("TeamNotFound") ) -type ( - team struct{} -) - -func Team() team { - return team{} -} - -func (r team) FindByID(ctx context.Context, id uint64) (*types.Team, error) { +func (r *repository) FindTeamByID(id uint64) (*types.Team, error) { db := factory.Database.MustGet() sql := "SELECT * FROM teams WHERE id = ? AND " + sqlTeamScope mod := &types.Team{} @@ -29,7 +34,7 @@ func (r team) FindByID(ctx context.Context, id uint64) (*types.Team, error) { return mod, isFound(db.Get(mod, sql, id), mod.ID > 0, ErrTeamNotFound) } -func (r team) Find(ctx context.Context, filter *types.TeamFilter) ([]*types.Team, error) { +func (r *repository) FindTeams(filter *types.TeamFilter) ([]*types.Team, error) { db := factory.Database.MustGet() rval := make([]*types.Team, 0) params := make([]interface{}, 0) @@ -45,38 +50,38 @@ func (r team) Find(ctx context.Context, filter *types.TeamFilter) ([]*types.Team sql += " ORDER BY name ASC" - return rval, db.With(ctx).Select(&rval, sql, params...) + return rval, db.With(r.ctx).Select(&rval, sql, params...) } -func (r team) Create(ctx context.Context, mod *types.Team) (*types.Team, error) { +func (r *repository) CreateTeam(mod *types.Team) (*types.Team, error) { mod.ID = factory.Sonyflake.NextID() mod.CreatedAt = time.Now() - return mod, factory.Database.MustGet().With(ctx).Insert("teams", mod) + return mod, factory.Database.MustGet().With(r.ctx).Insert("teams", mod) } -func (r team) Update(ctx context.Context, mod *types.Team) (*types.Team, error) { +func (r *repository) UpdateTeam(mod *types.Team) (*types.Team, error) { mod.UpdatedAt = timeNowPtr() - return mod, factory.Database.MustGet().With(ctx).Replace("teams", mod) + return mod, factory.Database.MustGet().With(r.ctx).Replace("teams", mod) } -func (r team) Archive(ctx context.Context, id uint64) error { - return simpleUpdate(ctx, "teams", "archived_at", time.Now(), id) +func (r *repository) ArchiveTeamByID(id uint64) error { + return simpleUpdate(r.ctx, "teams", "archived_at", time.Now(), id) } -func (r team) Unarchive(ctx context.Context, id uint64) error { - return simpleUpdate(ctx, "teams", "archived_at", nil, id) +func (r *repository) UnarchiveTeamByID(id uint64) error { + return simpleUpdate(r.ctx, "teams", "archived_at", nil, id) } -func (r team) Delete(ctx context.Context, id uint64) error { - return simpleDelete(ctx, "teams", id) +func (r *repository) DeleteTeamByID(id uint64) error { + return simpleDelete(r.ctx, "teams", id) } -func (r team) Merge(ctx context.Context, id, targetTeamID uint64) error { +func (r *repository) MergeTeamByID(id, targetTeamID uint64) error { return ErrNotImplemented } -func (r team) Move(ctx context.Context, id, targetOrganisationID uint64) error { +func (r *repository) MoveTeamByID(id, targetOrganisationID uint64) error { return ErrNotImplemented } diff --git a/sam/repository/user.go b/sam/repository/user.go index 345e29013..c3d0de67d 100644 --- a/sam/repository/user.go +++ b/sam/repository/user.go @@ -1,12 +1,24 @@ package repository import ( - "context" "github.com/crusttech/crust/sam/types" "github.com/titpetric/factory" "time" ) +type ( + User interface { + FindUserByUsername(username string) (*types.User, error) + FindUserByID(id uint64) (*types.User, error) + FindUsers(filter *types.UserFilter) ([]*types.User, error) + CreateUser(mod *types.User) (*types.User, error) + UpdateUser(mod *types.User) (*types.User, error) + SuspendUserByID(id uint64) error + UnsuspendUserByID(id uint64) error + DeleteUserByID(id uint64) error + } +) + const ( sqlUserScope = "suspended_at IS NULL AND deleted_at IS NULL" sqlUserSelect = "SELECT * FROM users WHERE " + sqlUserScope @@ -14,32 +26,21 @@ const ( ErrUserNotFound = repositoryError("UserNotFound") ) -type ( - user struct{} -) - -func User() user { - return user{} -} - -func (r user) FindByUsername(ctx context.Context, username string) (*types.User, error) { - db := factory.Database.MustGet() +func (r *repository) FindUserByUsername(username string) (*types.User, error) { sql := "SELECT * FROM users WHERE username = ? AND " + sqlUserScope mod := &types.User{} - return mod, isFound(db.Get(mod, sql, username), mod.ID > 0, ErrUserNotFound) + return mod, isFound(r.db().Get(mod, sql, username), mod.ID > 0, ErrUserNotFound) } -func (r user) FindByID(ctx context.Context, id uint64) (*types.User, error) { - db := factory.Database.MustGet() +func (r *repository) FindUserByID(id uint64) (*types.User, error) { sql := "SELECT * FROM users WHERE id = ? AND " + sqlUserScope mod := &types.User{} - return mod, isFound(db.Get(mod, sql, id), mod.ID > 0, ErrUserNotFound) + return mod, isFound(r.db().Get(mod, sql, id), mod.ID > 0, ErrUserNotFound) } -func (r user) Find(ctx context.Context, filter *types.UserFilter) ([]*types.User, error) { - db := factory.Database.MustGet() +func (r *repository) FindUsers(filter *types.UserFilter) ([]*types.User, error) { rval := make([]*types.User, 0) params := make([]interface{}, 0) sql := "SELECT * FROM users WHERE " + sqlUserScope @@ -58,32 +59,32 @@ func (r user) Find(ctx context.Context, filter *types.UserFilter) ([]*types.User sql += " ORDER BY username ASC" - return rval, db.With(ctx).Select(&rval, sql, params...) + return rval, r.db().Select(&rval, sql, params...) } -func (r user) Create(ctx context.Context, mod *types.User) (*types.User, error) { +func (r *repository) CreateUser(mod *types.User) (*types.User, error) { mod.ID = factory.Sonyflake.NextID() mod.CreatedAt = time.Now() mod.Meta = coalesceJson(mod.Meta, []byte("{}")) - return mod, factory.Database.MustGet().With(ctx).Insert("users", mod) + return mod, r.db().Insert("users", mod) } -func (r user) Update(ctx context.Context, mod *types.User) (*types.User, error) { +func (r *repository) UpdateUser(mod *types.User) (*types.User, error) { mod.UpdatedAt = timeNowPtr() mod.Meta = coalesceJson(mod.Meta, []byte("{}")) - return mod, factory.Database.MustGet().With(ctx).Replace("users", mod) + return mod, r.db().Replace("users", mod) } -func (r user) Suspend(ctx context.Context, id uint64) error { - return simpleUpdate(ctx, "users", "suspend_at", time.Now(), id) +func (r *repository) SuspendUserByID(id uint64) error { + return simpleUpdate(r.ctx, "users", "suspend_at", time.Now(), id) } -func (r user) Unsuspend(ctx context.Context, id uint64) error { - return simpleUpdate(ctx, "users", "suspend_at", nil, id) +func (r *repository) UnsuspendUserByID(id uint64) error { + return simpleUpdate(r.ctx, "users", "suspend_at", nil, id) } -func (r user) Delete(ctx context.Context, id uint64) error { - return simpleDelete(ctx, "users", id) +func (r *repository) DeleteUserByID(id uint64) error { + return simpleDelete(r.ctx, "users", id) } diff --git a/sam/service/channel.go b/sam/service/channel.go index 7a0d68ee5..2235cbce7 100644 --- a/sam/service/channel.go +++ b/sam/service/channel.go @@ -8,41 +8,56 @@ import ( type ( channel struct { - repository channelRepository + rpo channelRepository + // + //sec struct { + // ch channelSecurity + //} } channelRepository interface { - FindByID(ctx context.Context, channelID uint64) (*types.Channel, error) - Find(ctx context.Context, filter *types.ChannelFilter) ([]*types.Channel, error) - - Create(ctx context.Context, channel *types.Channel) (*types.Channel, error) - Update(ctx context.Context, channel *types.Channel) (*types.Channel, error) - - deleter - archiver + repository.Transactionable + repository.Channel } + + //channelSecurity interface { + // CanRead(ctx context.Context, ch *types.Channel) bool + //} ) func Channel() *channel { - return &channel{repository: repository.Channel()} + var svc = &channel{} + + svc.rpo = repository.New() + //svc.sec.ch = ChannelSecurity(svc.rpo) + + return svc } -func (svc channel) FindByID(ctx context.Context, id uint64) (*types.Channel, error) { - // @todo: permission check if current user can read channel - return svc.repository.FindByID(ctx, id) +func (svc channel) FindByID(ctx context.Context, id uint64) (ch *types.Channel, err error) { + ch, err = svc.rpo.FindChannelByID(id) + if err != nil { + return + } + + //if !svc.sec.ch.CanRead(ch) { + // return nil, errors.New("Not allowed to access channel") + //} + + return } func (svc channel) Find(ctx context.Context, filter *types.ChannelFilter) ([]*types.Channel, error) { // @todo: permission check to return only channels that channel has access to // @todo: actual searching not just a full select - return svc.repository.Find(ctx, filter) + return svc.rpo.FindChannels(filter) } func (svc channel) Create(ctx context.Context, mod *types.Channel) (*types.Channel, error) { // @todo: topic channelEvent/log entry // @todo: channel name cmessage/log entry // @todo: permission check if channel can add channel - return svc.repository.Create(ctx, mod) + return svc.rpo.CreateChannel(mod) } func (svc channel) Update(ctx context.Context, mod *types.Channel) (*types.Channel, error) { @@ -54,26 +69,58 @@ func (svc channel) Update(ctx context.Context, mod *types.Channel) (*types.Chann // @todo: handle channel movinga // @todo: handle channel archiving - return svc.repository.Update(ctx, mod) + return svc.rpo.UpdateChannel(mod) } func (svc channel) Delete(ctx context.Context, id uint64) error { // @todo: make history unavailable // @todo: notify users that channel has been removed (remove from web UI) // @todo: permissions check if current user can remove channel - return svc.repository.Delete(ctx, id) + return svc.rpo.DeleteChannelByID(id) } func (svc channel) Archive(ctx context.Context, id uint64) error { // @todo: make history unavailable // @todo: notify users that channel has been removed (remove from web UI) // @todo: permissions check if current user can remove channel - return svc.repository.Archive(ctx, id) + return svc.rpo.ArchiveChannelByID(id) } func (svc channel) Unarchive(ctx context.Context, id uint64) error { // @todo: make history unavailable // @todo: notify users that channel has been removed (remove from web UI) // @todo: permissions check if current user can remove channel - return svc.repository.Unarchive(ctx, id) + return svc.rpo.UnarchiveChannelByID(id) } + +//// @todo temp location, move this somewhere else +//type ( +// nativeChannelSec struct { +// rpo struct { +// ch nativeChannelSecChRepo +// } +// } +// +// nativeChannelSecChRepo interface { +// FindMember(ctx context.Context, channelId uint64, userId uint64) (*types.User, error) +// } +//) +// +//func ChannelSecurity(chRpo nativeChannelSecChRepo) channelSecurity { +// var sec = &nativeChannelSec{} +// +// sec.rpo.ch = chRpo +// +// return sec +//} +// +//// Current user can read the channel if he is a member +//func (sec nativeChannelSec) CanRead(ctx context.Context, ch *types.Channel) bool { +// // @todo check if channel is public? +// +// var currentUserID = auth.GetIdentityFromContext(ctx).Identity() +// +// user, err := sec.rpo.FindMember(ch.ID, currentUserID) +// +// return err != nil && user.Valid() +//} diff --git a/sam/service/message.go b/sam/service/message.go index 09ef62e32..f7a7ecb21 100644 --- a/sam/service/message.go +++ b/sam/service/message.go @@ -9,42 +9,19 @@ import ( type ( message struct { - repository struct { - message messageRepository - reaction messageReactionRepository - attachment messageAttachmentRepository - } + rpo messageRepository } messageRepository interface { - FindByID(ctx context.Context, messageID uint64) (*types.Message, error) - Find(ctx context.Context, filter *types.MessageFilter) ([]*types.Message, error) - - Create(ctx context.Context, message *types.Message) (*types.Message, error) - Update(ctx context.Context, message *types.Message) (*types.Message, error) - - deleter - } - - messageReactionRepository interface { - FindByID(ctx context.Context, reactionID uint64) (*types.Reaction, error) - Create(ctx context.Context, reaction *types.Reaction) (*types.Reaction, error) - Delete(ctx context.Context, reactionID uint64) error - } - - messageAttachmentRepository interface { - FindByID(ctx context.Context, attachmentID uint64) (*types.Attachment, error) - Create(ctx context.Context, attachment *types.Attachment) (*types.Attachment, error) - Delete(ctx context.Context, attachmentID uint64) error + repository.Transactionable + repository.Message + repository.Reaction + repository.Attachment } ) func Message() *message { - m := &message{} - m.repository.message = repository.Message() - m.repository.reaction = repository.Reaction() - m.repository.attachment = repository.Attachment() - + m := &message{rpo: repository.New()} return m } @@ -56,7 +33,7 @@ func (svc message) Find(ctx context.Context, filter *types.MessageFilter) ([]*ty _ = currentUserID _ = filter.ChannelID - return svc.repository.message.Find(ctx, filter) + return svc.rpo.FindMessages(filter) } func (svc message) Create(ctx context.Context, mod *types.Message) (*types.Message, error) { @@ -66,7 +43,7 @@ func (svc message) Create(ctx context.Context, mod *types.Message) (*types.Messa // @todo verify if current user can access & write to this channel _ = currentUserID - return svc.repository.message.Create(ctx, mod) + return svc.rpo.CreateMessage(mod) } func (svc message) Update(ctx context.Context, mod *types.Message) (*types.Message, error) { @@ -80,7 +57,7 @@ func (svc message) Update(ctx context.Context, mod *types.Message) (*types.Messa // @todo verify ownership - return svc.repository.message.Update(ctx, mod) + return svc.rpo.UpdateMessage(mod) } func (svc message) Delete(ctx context.Context, id uint64) error { @@ -94,7 +71,7 @@ func (svc message) Delete(ctx context.Context, id uint64) error { // @todo verify ownership - return svc.repository.message.Delete(ctx, id) + return svc.rpo.DeleteMessageByID(id) } func (svc message) React(ctx context.Context, messageID uint64, reaction string) error { @@ -113,7 +90,7 @@ func (svc message) React(ctx context.Context, messageID uint64, reaction string) Reaction: reaction, } - if _, err := svc.repository.reaction.Create(ctx, r); err != nil { + if _, err := svc.rpo.CreateReaction(r); err != nil { return err } @@ -130,7 +107,7 @@ func (svc message) Unreact(ctx context.Context, messageID uint64, reaction strin // @todo load reaction and verify ownership var r *types.Reaction - return svc.repository.reaction.Delete(ctx, r.ID) + return svc.rpo.DeleteReactionByID(r.ID) } func (svc message) Pin(ctx context.Context, messageID uint64) error { @@ -194,5 +171,5 @@ func (svc message) Detach(ctx context.Context, attachmentID uint64) error { // @todo verify if current user can remove this attachment - return svc.repository.attachment.Delete(ctx, attachmentID) + return svc.rpo.DeleteAttachmentByID(attachmentID) } diff --git a/sam/service/organisation.go b/sam/service/organisation.go index 239ea7c9d..317644162 100644 --- a/sam/service/organisation.go +++ b/sam/service/organisation.go @@ -8,67 +8,61 @@ import ( type ( organisation struct { - repository organisationRepository + rpo organisationRepository } organisationRepository interface { - FindByID(ctx context.Context, organisationID uint64) (*types.Organisation, error) - Find(ctx context.Context, filter *types.OrganisationFilter) ([]*types.Organisation, error) - - Create(ctx context.Context, organisation *types.Organisation) (*types.Organisation, error) - Update(ctx context.Context, organisation *types.Organisation) (*types.Organisation, error) - - deleter - archiver + repository.Transactionable + repository.Organisation } ) func Organisation() *organisation { - return &organisation{repository: repository.Organisation()} + return &organisation{rpo: repository.New()} } func (svc organisation) FindByID(ctx context.Context, id uint64) (*types.Organisation, error) { // @todo: permission check if current user can read organisation - return svc.repository.FindByID(ctx, id) + return svc.rpo.FindOrganisationByID(id) } func (svc organisation) Find(ctx context.Context, filter *types.OrganisationFilter) ([]*types.Organisation, error) { // @todo: permission check to return only organisations that organisation has access to // @todo: actual searching not just a full select - return svc.repository.Find(ctx, filter) + return svc.rpo.FindOrganisations(filter) } func (svc organisation) Create(ctx context.Context, mod *types.Organisation) (*types.Organisation, error) { // @todo: permission check if current user can add/edit organisation // @todo: make sure archived & deleted entries can not be edited - return svc.repository.Create(ctx, mod) + return svc.rpo.CreateOrganisation(mod) } func (svc organisation) Update(ctx context.Context, mod *types.Organisation) (*types.Organisation, error) { // @todo: permission check if current user can add/edit organisation // @todo: make sure archived & deleted entries can not be edited - return svc.repository.Update(ctx, mod) + return svc.rpo.UpdateOrganisation(mod) } func (svc organisation) Delete(ctx context.Context, id uint64) error { // @todo: permissions check if current user can remove organisation // @todo: make history unavailable // @todo: notify users that organisation has been removed (remove from web UI) - return svc.repository.Delete(ctx, id) + return svc.rpo.DeleteOrganisationByID(id) } func (svc organisation) Archive(ctx context.Context, id uint64) error { // @todo: make history unavailable // @todo: notify users that organisation has been removed (remove from web UI) // @todo: permissions check if current user can archive organisation - return svc.repository.Archive(ctx, id) + return svc.rpo.ArchiveOrganisationByID(id) } func (svc organisation) Unarchive(ctx context.Context, id uint64) error { // @todo: make history unavailable // @todo: notify users that organisation has been removed (remove from web UI) // @todo: permissions check if current user can unarchive organisation - return svc.repository.Unarchive(ctx, id) + return svc.rpo.UnarchiveOrganisationByID(id) } diff --git a/sam/service/service.go b/sam/service/service.go deleted file mode 100644 index 7674dfb79..000000000 --- a/sam/service/service.go +++ /dev/null @@ -1,20 +0,0 @@ -package service - -import ( - "context" -) - -type ( - suspender interface { - Suspend(ctx context.Context, ID uint64) error - Unsuspend(ctx context.Context, ID uint64) error - } - archiver interface { - Archive(ctx context.Context, ID uint64) error - Unarchive(ctx context.Context, ID uint64) error - } - - deleter interface { - Delete(ctx context.Context, ID uint64) error - } -) diff --git a/sam/service/team.go b/sam/service/team.go index 9e8f299bf..6ca2d8c85 100644 --- a/sam/service/team.go +++ b/sam/service/team.go @@ -8,78 +8,69 @@ import ( type ( team struct { - repository teamRepository + rpo teamRepository } teamRepository interface { - FindByID(ctx context.Context, teamID uint64) (*types.Team, error) - Find(ctx context.Context, filter *types.TeamFilter) ([]*types.Team, error) - - Create(ctx context.Context, team *types.Team) (*types.Team, error) - Update(ctx context.Context, team *types.Team) (*types.Team, error) - - Merge(ctx context.Context, teamID, targetTeamID uint64) error - Move(ctx context.Context, teamID, organisationID uint64) error - - deleter - archiver + repository.Transactionable + repository.Team } ) func Team() *team { - return &team{repository: repository.Team()} + return &team{rpo: repository.New()} } func (svc team) FindByID(ctx context.Context, id uint64) (*types.Team, error) { // @todo: permission check if current user has access to this team - return svc.repository.FindByID(ctx, id) + return svc.rpo.FindTeamByID(id) } func (svc team) Find(ctx context.Context, filter *types.TeamFilter) ([]*types.Team, error) { // @todo: permission check to return only teams that current user has access to - return svc.repository.Find(ctx, filter) + return svc.rpo.FindTeams(filter) } func (svc team) Create(ctx context.Context, mod *types.Team) (*types.Team, error) { // @todo: permission check if current user can add/edit team - return svc.repository.Create(ctx, mod) + return svc.rpo.CreateTeam(mod) } func (svc team) Update(ctx context.Context, mod *types.Team) (*types.Team, error) { // @todo: permission check if current user can add/edit team // @todo: make sure archived & deleted entries can not be edited - return svc.repository.Update(ctx, mod) + return svc.rpo.UpdateTeam(mod) } func (svc team) Delete(ctx context.Context, id uint64) error { // @todo: make history unavailable // @todo: notify users that team has been removed (remove from web UI) // @todo: permissions check if current user can remove team - return svc.repository.Delete(ctx, id) + return svc.rpo.DeleteTeamByID(id) } func (svc team) Archive(ctx context.Context, id uint64) error { // @todo: make history unavailable // @todo: notify users that team has been removed (remove from web UI) // @todo: permissions check if current user can remove team - return svc.repository.Archive(ctx, id) + return svc.rpo.ArchiveTeamByID(id) } func (svc team) Unarchive(ctx context.Context, id uint64) error { // @todo: permissions check if current user can unarchive team // @todo: make history accessible // @todo: notify users that team has been unarchived - return svc.repository.Unarchive(ctx, id) + return svc.rpo.UnarchiveTeamByID(id) } func (svc team) Merge(ctx context.Context, id, targetTeamID uint64) error { // @todo: permission check if current user can merge team - return svc.repository.Merge(ctx, id, targetTeamID) + return svc.rpo.MergeTeamByID(id, targetTeamID) } func (svc team) Move(ctx context.Context, id, targetOrganisationID uint64) error { // @todo: permission check if current user can move team to another organisation - return svc.repository.Move(ctx, id, targetOrganisationID) + return svc.rpo.MoveTeamByID(id, targetOrganisationID) } diff --git a/sam/service/user.go b/sam/service/user.go index 91a6dbcd4..076aaf9f1 100644 --- a/sam/service/user.go +++ b/sam/service/user.go @@ -14,28 +14,22 @@ const ( type ( user struct { - repository userRepository + rpo userRepository } userRepository interface { - FindByUsername(ctx context.Context, username string) (*types.User, error) - FindByID(ctx context.Context, userID uint64) (*types.User, error) - Find(ctx context.Context, filter *types.UserFilter) ([]*types.User, error) - - Create(ctx context.Context, user *types.User) (*types.User, error) - Update(ctx context.Context, user *types.User) (*types.User, error) - - deleter - suspender + repository.Transactionable + repository.Contextable + repository.User } ) func User() *user { - return &user{repository: repository.User()} + return &user{rpo: repository.New()} } func (svc user) ValidateCredentials(ctx context.Context, username, password string) (*types.User, error) { - user, err := svc.repository.FindByUsername(ctx, username) + user, err := svc.rpo.FindUserByUsername(username) if err != nil { return nil, err } @@ -52,19 +46,26 @@ func (svc user) ValidateCredentials(ctx context.Context, username, password stri } func (svc user) FindByID(ctx context.Context, id uint64) (*types.User, error) { - return svc.repository.FindByID(ctx, id) + return svc.rpo.WithCtx(ctx).FindUserByID(id) } func (svc user) Find(ctx context.Context, filter *types.UserFilter) ([]*types.User, error) { - return svc.repository.Find(ctx, filter) + return svc.rpo.FindUsers(filter) } -func (svc user) Create(ctx context.Context, mod *types.User) (*types.User, error) { - return svc.repository.Create(ctx, mod) +func (svc user) Create(ctx context.Context, input *types.User) (new *types.User, err error) { + // no real need for tx here, just presenting the capabilities + return new, svc.rpo.BeginWith(ctx, func(r repository.Interfaces) error { + if new, err = r.CreateUser(input); err != nil { + return err + } + + return nil + }) } func (svc user) Update(ctx context.Context, mod *types.User) (*types.User, error) { - return svc.repository.Update(ctx, mod) + return svc.rpo.UpdateUser(mod) } func (svc user) validatePassword(user *types.User, password string) bool { @@ -89,17 +90,17 @@ func (svc user) canLogin(u *types.User) bool { func (svc user) Delete(ctx context.Context, id uint64) error { // @todo: permissions check if current user can delete this user // @todo: notify users that user has been removed (remove from web UI) - return svc.repository.Delete(ctx, id) + return svc.rpo.DeleteUserByID(id) } func (svc user) Suspend(ctx context.Context, id uint64) error { // @todo: permissions check if current user can suspend this user // @todo: notify users that user has been supsended (remove from web UI) - return svc.repository.Suspend(ctx, id) + return svc.rpo.SuspendUserByID(id) } func (svc user) Unsuspend(ctx context.Context, id uint64) error { // @todo: permissions check if current user can unsuspend this user // @todo: notify users that user has been unsuspended - return svc.repository.Unsuspend(ctx, id) + return svc.rpo.UnsuspendUserByID(id) }