3
0

Refactor repository, add support for transactions

This commit is contained in:
Denis Arh
2018-07-29 23:31:23 +02:00
parent 463afa77d0
commit baae430adc
17 changed files with 406 additions and 308 deletions

View File

@@ -1,35 +1,38 @@
package repository
import (
"context"
"github.com/crusttech/crust/sam/types"
"github.com/titpetric/factory"
"time"
)
type (
Attachment interface {
FindAttachmentByID(id uint64) (*types.Attachment, error)
FindAttachmentByRange(channelID, fromAttachmentID, toAttachmentID uint64) ([]*types.Attachment, error)
CreateAttachment(mod *types.Attachment) (*types.Attachment, error)
UpdateAttachment(mod *types.Attachment) (*types.Attachment, error)
DeleteAttachmentByID(id uint64) error
}
)
const (
sqlAttachmentScope = "deleted_at IS NULL"
ErrAttachmentNotFound = repositoryError("AttachmentNotFound")
)
type (
attachment struct{}
)
var _ Attachment = &repository{}
func Attachment() attachment {
return attachment{}
}
func (r attachment) FindByID(ctx context.Context, id uint64) (*types.Attachment, error) {
func (r *repository) FindAttachmentByID(id uint64) (*types.Attachment, error) {
db := factory.Database.MustGet()
sql := "SELECT * FROM attachments WHERE id = ? AND " + sqlAttachmentScope
mod := &types.Attachment{}
return mod, isFound(db.With(ctx).Get(mod, sql, id), mod.ID > 0, ErrAttachmentNotFound)
return mod, isFound(db.With(r.ctx).Get(mod, sql, id), mod.ID > 0, ErrAttachmentNotFound)
}
func (r attachment) FindByRange(ctx context.Context, channelID, fromAttachmentID, toAttachmentID uint64) ([]*types.Attachment, error) {
func (r *repository) FindAttachmentByRange(channelID, fromAttachmentID, toAttachmentID uint64) ([]*types.Attachment, error) {
db := factory.Database.MustGet()
rval := make([]*types.Attachment, 0)
@@ -40,25 +43,25 @@ func (r attachment) FindByRange(ctx context.Context, channelID, fromAttachmentID
AND rel_channel = ?
AND deleted_at IS NULL`
return rval, db.With(ctx).Select(&rval, sql, fromAttachmentID, toAttachmentID, channelID)
return rval, db.With(r.ctx).Select(&rval, sql, fromAttachmentID, toAttachmentID, channelID)
}
func (r attachment) Create(ctx context.Context, mod *types.Attachment) (*types.Attachment, error) {
func (r *repository) CreateAttachment(mod *types.Attachment) (*types.Attachment, error) {
mod.ID = factory.Sonyflake.NextID()
mod.CreatedAt = time.Now()
mod.Attachment = coalesceJson(mod.Attachment, []byte("{}"))
return mod, factory.Database.MustGet().With(ctx).Insert("attachments", mod)
return mod, factory.Database.MustGet().With(r.ctx).Insert("attachments", mod)
}
func (r attachment) Update(ctx context.Context, mod *types.Attachment) (*types.Attachment, error) {
func (r *repository) UpdateAttachment(mod *types.Attachment) (*types.Attachment, error) {
mod.UpdatedAt = timeNowPtr()
mod.Attachment = coalesceJson(mod.Attachment, []byte("{}"))
return mod, factory.Database.MustGet().With(ctx).Replace("attachments", mod)
return mod, factory.Database.MustGet().With(r.ctx).Replace("attachments", mod)
}
func (r attachment) Delete(ctx context.Context, id uint64) error {
return simpleDelete(ctx, "attachments", id)
func (r *repository) DeleteAttachmentByID(id uint64) error {
return simpleDelete(r.ctx, "attachments", id)
}

View File

@@ -22,7 +22,7 @@ func TestAttachment(t *testing.T) {
att.ChannelID = 1
att, err = rpo.Create(ctx, att)
att, err = rpo.CreateAttachment(ctx, att)
must(t, err)
if att.ChannelID != 1 {
t.Fatal("Changes were not stored")
@@ -30,23 +30,23 @@ func TestAttachment(t *testing.T) {
att.ChannelID = 2
att, err = rpo.Update(ctx, att)
att, err = rpo.UpdateAttachment(ctx, att)
must(t, err)
if att.ChannelID != 2 {
t.Fatal("Changes were not stored")
}
att, err = rpo.FindByID(ctx, att.ID)
att, err = rpo.FindAttachmentByID(ctx, att.ID)
must(t, err)
if att.ChannelID != 2 {
t.Fatal("Changes were not stored")
}
aa, err = rpo.FindByRange(ctx, 2, 0, att.ID)
aa, err = rpo.FindAttachmentByRange(ctx, 2, 0, att.ID)
must(t, err)
if len(aa) == 0 {
t.Fatal("No results found")
}
must(t, rpo.Delete(ctx, att.ID))
must(t, rpo.DeleteAttachmentByID(ctx, att.ID))
}

View File

@@ -1,35 +1,40 @@
package repository
import (
"context"
"github.com/crusttech/crust/sam/types"
"github.com/titpetric/factory"
"time"
)
type (
Channel interface {
FindChannelByID(id uint64) (*types.Channel, error)
FindChannels(filter *types.ChannelFilter) ([]*types.Channel, error)
CreateChannel(mod *types.Channel) (*types.Channel, error)
UpdateChannel(mod *types.Channel) (*types.Channel, error)
AddChannelMember(channelID, userID uint64) error
RemoveChannelMember(channelID, userID uint64) error
ArchiveChannelByID(id uint64) error
UnarchiveChannelByID(id uint64) error
DeleteChannelByID(id uint64) error
}
)
const (
sqlChannelScope = "deleted_at IS NULL AND archived_at IS NULL"
ErrChannelNotFound = repositoryError("ChannelNotFound")
)
type (
channel struct{}
)
func Channel() channel {
return channel{}
}
func (r channel) FindByID(ctx context.Context, id uint64) (*types.Channel, error) {
func (r *repository) FindChannelByID(id uint64) (*types.Channel, error) {
db := factory.Database.MustGet()
mod := &types.Channel{}
sql := "SELECT * FROM channels WHERE id = ? AND " + sqlChannelScope
return mod, isFound(db.With(ctx).Get(mod, sql, id), mod.ID > 0, ErrChannelNotFound)
return mod, isFound(db.With(r.ctx).Get(mod, sql, id), mod.ID > 0, ErrChannelNotFound)
}
func (r channel) Find(ctx context.Context, filter *types.ChannelFilter) ([]*types.Channel, error) {
func (r *repository) FindChannels(filter *types.ChannelFilter) ([]*types.Channel, error) {
db := factory.Database.MustGet()
params := make([]interface{}, 0)
rval := make([]*types.Channel, 0)
@@ -45,42 +50,42 @@ func (r channel) Find(ctx context.Context, filter *types.ChannelFilter) ([]*type
sql += " ORDER BY name ASC"
return rval, db.With(ctx).Select(&rval, sql, params...)
return rval, db.With(r.ctx).Select(&rval, sql, params...)
}
func (r channel) Create(ctx context.Context, mod *types.Channel) (*types.Channel, error) {
func (r *repository) CreateChannel(mod *types.Channel) (*types.Channel, error) {
mod.ID = factory.Sonyflake.NextID()
mod.CreatedAt = time.Now()
mod.Meta = coalesceJson(mod.Meta, []byte("{}"))
return mod, factory.Database.MustGet().With(ctx).Insert("channels", mod)
return mod, factory.Database.MustGet().With(r.ctx).Insert("channels", mod)
}
func (r channel) Update(ctx context.Context, mod *types.Channel) (*types.Channel, error) {
func (r *repository) UpdateChannel(mod *types.Channel) (*types.Channel, error) {
mod.UpdatedAt = timeNowPtr()
mod.Meta = coalesceJson(mod.Meta, []byte("{}"))
return mod, factory.Database.MustGet().With(ctx).Replace("channels", mod)
return mod, factory.Database.MustGet().With(r.ctx).Replace("channels", mod)
}
func (r channel) AddMember(ctx context.Context, channelID, userID uint64) error {
func (r *repository) AddChannelMember(channelID, userID uint64) error {
sql := `INSERT INTO channel_members (rel_channel, rel_user) VALUES (?, ?)`
return exec(factory.Database.MustGet().With(ctx).Exec(sql, channelID, userID))
return exec(factory.Database.MustGet().With(r.ctx).Exec(sql, channelID, userID))
}
func (r channel) RemoveMember(ctx context.Context, channelID, userID uint64) error {
func (r *repository) RemoveChannelMember(channelID, userID uint64) error {
sql := `DELETE FROM channel_members WHERE rel_channel = ? AND rel_user = ?`
return exec(factory.Database.MustGet().With(ctx).Exec(sql, channelID, userID))
return exec(factory.Database.MustGet().With(r.ctx).Exec(sql, channelID, userID))
}
func (r channel) Archive(ctx context.Context, id uint64) error {
return simpleUpdate(ctx, "channels", "archived_at", time.Now(), id)
func (r *repository) ArchiveChannelByID(id uint64) error {
return simpleUpdate(r.ctx, "channels", "archived_at", time.Now(), id)
}
func (r channel) Unarchive(ctx context.Context, id uint64) error {
return simpleUpdate(ctx, "channels", "archived_at", nil, id)
func (r *repository) UnarchiveChannelByID(id uint64) error {
return simpleUpdate(r.ctx, "channels", "archived_at", nil, id)
}
func (r channel) Delete(ctx context.Context, id uint64) error {
return simpleDelete(ctx, "channels", id)
func (r *repository) DeleteChannelByID(id uint64) error {
return simpleDelete(r.ctx, "channels", id)
}

View File

@@ -24,7 +24,7 @@ func TestChannel(t *testing.T) {
chn.Name = name1
chn, err = rpo.Create(ctx, chn)
chn, err = rpo.CreateChannel(ctx, chn)
must(t, err)
if chn.Name != name1 {
t.Fatal("Changes were not stored")
@@ -32,27 +32,27 @@ func TestChannel(t *testing.T) {
chn.Name = name2
chn, err = rpo.Update(ctx, chn)
chn, err = rpo.UpdateChannel(ctx, chn)
must(t, err)
if chn.Name != name2 {
t.Fatal("Changes were not stored")
}
chn, err = rpo.FindByID(ctx, chn.ID)
chn, err = rpo.FindChannelByID(ctx, chn.ID)
must(t, err)
if chn.Name != name2 {
t.Fatal("Changes were not stored")
}
cc, err = rpo.Find(ctx, &types.ChannelFilter{Query: name2})
cc, err = rpo.FindChannels(ctx, &types.ChannelFilter{Query: name2})
must(t, err)
if len(cc) == 0 {
t.Fatal("No results found")
}
must(t, rpo.Archive(ctx, chn.ID))
must(t, rpo.Unarchive(ctx, chn.ID))
must(t, rpo.Delete(ctx, chn.ID))
must(t, rpo.ArchiveChannel(ctx, chn.ID))
must(t, rpo.UnarchiveChannel(ctx, chn.ID))
must(t, rpo.DeleteChannelByID(ctx, chn.ID))
}
func TestChannelMembers(t *testing.T) {
@@ -67,13 +67,13 @@ func TestChannelMembers(t *testing.T) {
ctx := context.Background()
chn := &types.Channel{}
chn, err = rpo.Create(ctx, chn)
chn, err = rpo.CreateChannel(ctx, chn)
must(t, err)
usr := &types.User{}
usr, err = User().Create(ctx, usr)
must(t, err)
must(t, rpo.AddMember(ctx, chn.ID, usr.ID))
must(t, rpo.RemoveMember(ctx, chn.ID, usr.ID))
must(t, rpo.AddChannelMember(ctx, chn.ID, usr.ID))
must(t, rpo.RemoveChannelMember(ctx, chn.ID, usr.ID))
}

View File

@@ -1,35 +1,36 @@
package repository
import (
"context"
"github.com/crusttech/crust/sam/types"
"github.com/titpetric/factory"
"time"
)
type (
Message interface {
FindMessageByID(id uint64) (*types.Message, error)
FindMessages(filter *types.MessageFilter) ([]*types.Message, error)
CreateMessage(mod *types.Message) (*types.Message, error)
UpdateMessage(mod *types.Message) (*types.Message, error)
DeleteMessageByID(id uint64) error
}
)
const (
sqlMessageScope = "deleted_at IS NULL"
ErrMessageNotFound = repositoryError("MessageNotFound")
)
type (
message struct{}
)
func Message() message {
return message{}
}
func (r message) FindByID(ctx context.Context, id uint64) (*types.Message, error) {
func (r *repository) FindMessageByID(id uint64) (*types.Message, error) {
db := factory.Database.MustGet()
mod := &types.Message{}
sql := "SELECT id, COALESCE(type,'') AS type, message, rel_user, rel_channel, COALESCE(reply_to, 0) AS reply_to FROM messages WHERE id = ? AND " + sqlMessageScope
return mod, isFound(db.With(ctx).Get(mod, sql, id), mod.ID > 0, ErrMessageNotFound)
return mod, isFound(db.With(r.ctx).Get(mod, sql, id), mod.ID > 0, ErrMessageNotFound)
}
func (r message) Find(ctx context.Context, filter *types.MessageFilter) ([]*types.Message, error) {
func (r *repository) FindMessages(filter *types.MessageFilter) ([]*types.Message, error) {
db := factory.Database.MustGet()
params := make([]interface{}, 0)
rval := make([]*types.Message, 0)
@@ -65,22 +66,22 @@ func (r message) Find(ctx context.Context, filter *types.MessageFilter) ([]*type
sql += " LIMIT ? "
params = append(params, filter.Limit)
}
return rval, db.With(ctx).Select(&rval, sql, params...)
return rval, db.With(r.ctx).Select(&rval, sql, params...)
}
func (r message) Create(ctx context.Context, mod *types.Message) (*types.Message, error) {
func (r *repository) CreateMessage(mod *types.Message) (*types.Message, error) {
mod.ID = factory.Sonyflake.NextID()
mod.CreatedAt = time.Now()
return mod, factory.Database.MustGet().With(ctx).Insert("messages", mod)
return mod, factory.Database.MustGet().With(r.ctx).Insert("messages", mod)
}
func (r message) Update(ctx context.Context, mod *types.Message) (*types.Message, error) {
func (r *repository) UpdateMessage(mod *types.Message) (*types.Message, error) {
mod.UpdatedAt = timeNowPtr()
return mod, factory.Database.MustGet().With(ctx).Replace("messages", mod)
return mod, factory.Database.MustGet().With(r.ctx).Replace("messages", mod)
}
func (r message) Delete(ctx context.Context, id uint64) error {
return simpleDelete(ctx, "messages", id)
func (r *repository) DeleteMessageByID(id uint64) error {
return simpleDelete(r.ctx, "messages", id)
}

View File

@@ -24,7 +24,7 @@ func TestMessage(t *testing.T) {
msg.Message = msg1
msg, err = rpo.Create(ctx, msg)
msg, err = rpo.CreateMessage(ctx, msg)
must(t, err)
if msg.Message != msg1 {
t.Fatal("Changes were not stored")
@@ -38,13 +38,13 @@ func TestMessage(t *testing.T) {
t.Fatal("Changes were not stored")
}
msg, err = rpo.FindByID(ctx, msg.ID)
msg, err = rpo.FindMessageByID(ctx, msg.ID)
must(t, err)
if msg.Message != msg2 {
t.Fatal("Changes were not stored")
}
mm, err = rpo.Find(ctx, &types.MessageFilter{Query: msg2})
mm, err = rpo.FindMessages(ctx, &types.MessageFilter{Query: msg2})
must(t, err)
if len(mm) == 0 {
t.Fatal("No results found")

View File

@@ -1,27 +1,30 @@
package repository
import (
"context"
"github.com/crusttech/crust/sam/types"
"github.com/titpetric/factory"
"time"
)
type (
Organisation interface {
FindOrganisationByID(id uint64) (*types.Organisation, error)
FindOrganisations(filter *types.OrganisationFilter) ([]*types.Organisation, error)
CreateOrganisation(mod *types.Organisation) (*types.Organisation, error)
UpdateOrganisation(mod *types.Organisation) (*types.Organisation, error)
ArchiveOrganisationByID(id uint64) error
UnarchiveOrganisationByID(id uint64) error
DeleteOrganisationByID(id uint64) error
}
)
const (
sqlOrganisationScope = "deleted_at IS NULL AND archived_at IS NULL"
ErrOrganisationNotFound = repositoryError("OrganisationNotFound")
)
type (
organisation struct{}
)
func Organisation() organisation {
return organisation{}
}
func (r organisation) FindByID(ctx context.Context, id uint64) (*types.Organisation, error) {
func (r *repository) FindOrganisationByID(id uint64) (*types.Organisation, error) {
db := factory.Database.MustGet()
sql := "SELECT * FROM organisations WHERE id = ? AND " + sqlOrganisationScope
mod := &types.Organisation{}
@@ -29,7 +32,7 @@ func (r organisation) FindByID(ctx context.Context, id uint64) (*types.Organisat
return mod, isFound(db.Get(mod, sql, id), mod.ID > 0, ErrOrganisationNotFound)
}
func (r organisation) Find(ctx context.Context, filter *types.OrganisationFilter) ([]*types.Organisation, error) {
func (r *repository) FindOrganisations(filter *types.OrganisationFilter) ([]*types.Organisation, error) {
db := factory.Database.MustGet()
rval := make([]*types.Organisation, 0)
params := make([]interface{}, 0)
@@ -44,30 +47,30 @@ func (r organisation) Find(ctx context.Context, filter *types.OrganisationFilter
sql += " ORDER BY name ASC"
return rval, db.With(ctx).Select(&rval, sql, params...)
return rval, db.With(r.ctx).Select(&rval, sql, params...)
}
func (r organisation) Create(ctx context.Context, mod *types.Organisation) (*types.Organisation, error) {
func (r *repository) CreateOrganisation(mod *types.Organisation) (*types.Organisation, error) {
mod.ID = factory.Sonyflake.NextID()
mod.CreatedAt = time.Now()
return mod, factory.Database.MustGet().With(ctx).Insert("organisations", mod)
return mod, factory.Database.MustGet().With(r.ctx).Insert("organisations", mod)
}
func (r organisation) Update(ctx context.Context, mod *types.Organisation) (*types.Organisation, error) {
func (r *repository) UpdateOrganisation(mod *types.Organisation) (*types.Organisation, error) {
mod.UpdatedAt = timeNowPtr()
return mod, factory.Database.MustGet().With(ctx).Replace("organisations", mod)
return mod, factory.Database.MustGet().With(r.ctx).Replace("organisations", mod)
}
func (r organisation) Archive(ctx context.Context, id uint64) error {
return simpleUpdate(ctx, "organisations", "archived_at", time.Now(), id)
func (r *repository) ArchiveOrganisationByID(id uint64) error {
return simpleUpdate(r.ctx, "organisations", "archived_at", time.Now(), id)
}
func (r organisation) Unarchive(ctx context.Context, id uint64) error {
return simpleUpdate(ctx, "organisations", "archived_at", nil, id)
func (r *repository) UnarchiveOrganisationByID(id uint64) error {
return simpleUpdate(r.ctx, "organisations", "archived_at", nil, id)
}
func (r organisation) Delete(ctx context.Context, id uint64) error {
return simpleDelete(ctx, "organisations", id)
func (r *repository) DeleteOrganisationByID(id uint64) error {
return simpleDelete(r.ctx, "organisations", id)
}

View File

@@ -1,34 +1,34 @@
package repository
import (
"context"
"github.com/crusttech/crust/sam/types"
"github.com/titpetric/factory"
"time"
)
type (
Reaction interface {
FindReactionByID(id uint64) (*types.Reaction, error)
FindReactionsByRange(channelID, fromReactionID, toReactionID uint64) ([]*types.Reaction, error)
CreateReaction(mod *types.Reaction) (*types.Reaction, error)
DeleteReactionByID(id uint64) error
}
)
const (
ErrReactionNotFound = repositoryError("ReactionNotFound")
)
type (
reaction struct{}
)
func Reaction() reaction {
return reaction{}
}
func (r reaction) FindByID(ctx context.Context, id uint64) (*types.Reaction, error) {
func (r *repository) FindReactionByID(id uint64) (*types.Reaction, error) {
db := factory.Database.MustGet()
sql := "SELECT * FROM reactions WHERE id = ?"
mod := &types.Reaction{}
return mod, isFound(db.With(ctx).Get(mod, sql, id), mod.ID > 0, ErrReactionNotFound)
return mod, isFound(db.With(r.ctx).Get(mod, sql, id), mod.ID > 0, ErrReactionNotFound)
}
func (r reaction) FindByRange(ctx context.Context, channelID, fromReactionID, toReactionID uint64) ([]*types.Reaction, error) {
func (r *repository) FindReactionsByRange(channelID, fromReactionID, toReactionID uint64) ([]*types.Reaction, error) {
db := factory.Database.MustGet()
rval := make([]*types.Reaction, 0)
sql := `
@@ -37,18 +37,17 @@ func (r reaction) FindByRange(ctx context.Context, channelID, fromReactionID, to
WHERE rel_reaction BETWEEN ? AND ?
AND rel_channel = ?`
return rval, db.With(ctx).Select(&rval, sql, fromReactionID, toReactionID, channelID)
return rval, db.With(r.ctx).Select(&rval, sql, fromReactionID, toReactionID, channelID)
}
func (r reaction) Create(ctx context.Context, mod *types.Reaction) (*types.Reaction, error) {
func (r *repository) CreateReaction(mod *types.Reaction) (*types.Reaction, error) {
mod.ID = factory.Sonyflake.NextID()
mod.CreatedAt = time.Now()
return mod, factory.Database.MustGet().With(ctx).Insert("reactions", mod)
return mod, factory.Database.MustGet().With(r.ctx).Insert("reactions", mod)
}
func (r reaction) Delete(ctx context.Context, id uint64) error {
func (r *repository) DeleteReactionByID(id uint64) error {
db := factory.Database.MustGet()
return exec(db.With(ctx).Exec("DELETE FROM reactions WHERE id = ?", id))
return exec(db.With(r.ctx).Exec("DELETE FROM reactions WHERE id = ?", id))
}

View File

@@ -0,0 +1,91 @@
package repository
import (
"context"
"github.com/titpetric/factory"
)
type (
repository struct {
ctx context.Context
// Current transaction
tx *factory.DB
}
Transactionable interface {
BeginWith(ctx context.Context, callback BeginCallback) error
Begin() error
Rollback() error
Commit() error
}
Contextable interface {
WithCtx(ctx context.Context) Interfaces
}
Interfaces interface {
Transactionable
Contextable
Attachment
Channel
Message
Organisation
Reaction
Team
User
}
BeginCallback func(r Interfaces) error
)
func New() *repository {
return &repository{ctx: context.Background()}
}
func (r *repository) WithCtx(ctx context.Context) Interfaces {
return &repository{ctx: ctx, tx: r.tx}
}
func (r *repository) BeginWith(ctx context.Context, callback BeginCallback) error {
tx := r.tx
if tx == nil {
tx = factory.Database.MustGet().With(ctx)
}
txr := &repository{ctx: ctx, tx: tx}
if err := txr.Begin(); err != nil {
return err
}
if err := callback(txr); err != nil {
if err := txr.Rollback(); err != nil {
return err
}
return err
}
return txr.Commit()
}
func (r *repository) Begin() error {
// @todo implementation
return r.tx.Begin()
}
func (r *repository) Commit() error {
// @todo implementation
return r.tx.Commit()
}
func (r *repository) Rollback() error {
// @todo implementation
return r.tx.Rollback()
}
func (r *repository) db() *factory.DB {
return r.tx.With(r.ctx)
}

View File

@@ -1,27 +1,32 @@
package repository
import (
"context"
"github.com/crusttech/crust/sam/types"
"github.com/titpetric/factory"
"time"
)
type (
Team interface {
FindTeamByID(id uint64) (*types.Team, error)
FindTeams(filter *types.TeamFilter) ([]*types.Team, error)
CreateTeam(mod *types.Team) (*types.Team, error)
UpdateTeam(mod *types.Team) (*types.Team, error)
ArchiveTeamByID(id uint64) error
UnarchiveTeamByID(id uint64) error
DeleteTeamByID(id uint64) error
MergeTeamByID(id, targetTeamID uint64) error
MoveTeamByID(id, targetOrganisationID uint64) error
}
)
const (
sqlTeamScope = "deleted_at IS NULL AND archived_at IS NULL"
ErrTeamNotFound = repositoryError("TeamNotFound")
)
type (
team struct{}
)
func Team() team {
return team{}
}
func (r team) FindByID(ctx context.Context, id uint64) (*types.Team, error) {
func (r *repository) FindTeamByID(id uint64) (*types.Team, error) {
db := factory.Database.MustGet()
sql := "SELECT * FROM teams WHERE id = ? AND " + sqlTeamScope
mod := &types.Team{}
@@ -29,7 +34,7 @@ func (r team) FindByID(ctx context.Context, id uint64) (*types.Team, error) {
return mod, isFound(db.Get(mod, sql, id), mod.ID > 0, ErrTeamNotFound)
}
func (r team) Find(ctx context.Context, filter *types.TeamFilter) ([]*types.Team, error) {
func (r *repository) FindTeams(filter *types.TeamFilter) ([]*types.Team, error) {
db := factory.Database.MustGet()
rval := make([]*types.Team, 0)
params := make([]interface{}, 0)
@@ -45,38 +50,38 @@ func (r team) Find(ctx context.Context, filter *types.TeamFilter) ([]*types.Team
sql += " ORDER BY name ASC"
return rval, db.With(ctx).Select(&rval, sql, params...)
return rval, db.With(r.ctx).Select(&rval, sql, params...)
}
func (r team) Create(ctx context.Context, mod *types.Team) (*types.Team, error) {
func (r *repository) CreateTeam(mod *types.Team) (*types.Team, error) {
mod.ID = factory.Sonyflake.NextID()
mod.CreatedAt = time.Now()
return mod, factory.Database.MustGet().With(ctx).Insert("teams", mod)
return mod, factory.Database.MustGet().With(r.ctx).Insert("teams", mod)
}
func (r team) Update(ctx context.Context, mod *types.Team) (*types.Team, error) {
func (r *repository) UpdateTeam(mod *types.Team) (*types.Team, error) {
mod.UpdatedAt = timeNowPtr()
return mod, factory.Database.MustGet().With(ctx).Replace("teams", mod)
return mod, factory.Database.MustGet().With(r.ctx).Replace("teams", mod)
}
func (r team) Archive(ctx context.Context, id uint64) error {
return simpleUpdate(ctx, "teams", "archived_at", time.Now(), id)
func (r *repository) ArchiveTeamByID(id uint64) error {
return simpleUpdate(r.ctx, "teams", "archived_at", time.Now(), id)
}
func (r team) Unarchive(ctx context.Context, id uint64) error {
return simpleUpdate(ctx, "teams", "archived_at", nil, id)
func (r *repository) UnarchiveTeamByID(id uint64) error {
return simpleUpdate(r.ctx, "teams", "archived_at", nil, id)
}
func (r team) Delete(ctx context.Context, id uint64) error {
return simpleDelete(ctx, "teams", id)
func (r *repository) DeleteTeamByID(id uint64) error {
return simpleDelete(r.ctx, "teams", id)
}
func (r team) Merge(ctx context.Context, id, targetTeamID uint64) error {
func (r *repository) MergeTeamByID(id, targetTeamID uint64) error {
return ErrNotImplemented
}
func (r team) Move(ctx context.Context, id, targetOrganisationID uint64) error {
func (r *repository) MoveTeamByID(id, targetOrganisationID uint64) error {
return ErrNotImplemented
}

View File

@@ -1,12 +1,24 @@
package repository
import (
"context"
"github.com/crusttech/crust/sam/types"
"github.com/titpetric/factory"
"time"
)
type (
User interface {
FindUserByUsername(username string) (*types.User, error)
FindUserByID(id uint64) (*types.User, error)
FindUsers(filter *types.UserFilter) ([]*types.User, error)
CreateUser(mod *types.User) (*types.User, error)
UpdateUser(mod *types.User) (*types.User, error)
SuspendUserByID(id uint64) error
UnsuspendUserByID(id uint64) error
DeleteUserByID(id uint64) error
}
)
const (
sqlUserScope = "suspended_at IS NULL AND deleted_at IS NULL"
sqlUserSelect = "SELECT * FROM users WHERE " + sqlUserScope
@@ -14,32 +26,21 @@ const (
ErrUserNotFound = repositoryError("UserNotFound")
)
type (
user struct{}
)
func User() user {
return user{}
}
func (r user) FindByUsername(ctx context.Context, username string) (*types.User, error) {
db := factory.Database.MustGet()
func (r *repository) FindUserByUsername(username string) (*types.User, error) {
sql := "SELECT * FROM users WHERE username = ? AND " + sqlUserScope
mod := &types.User{}
return mod, isFound(db.Get(mod, sql, username), mod.ID > 0, ErrUserNotFound)
return mod, isFound(r.db().Get(mod, sql, username), mod.ID > 0, ErrUserNotFound)
}
func (r user) FindByID(ctx context.Context, id uint64) (*types.User, error) {
db := factory.Database.MustGet()
func (r *repository) FindUserByID(id uint64) (*types.User, error) {
sql := "SELECT * FROM users WHERE id = ? AND " + sqlUserScope
mod := &types.User{}
return mod, isFound(db.Get(mod, sql, id), mod.ID > 0, ErrUserNotFound)
return mod, isFound(r.db().Get(mod, sql, id), mod.ID > 0, ErrUserNotFound)
}
func (r user) Find(ctx context.Context, filter *types.UserFilter) ([]*types.User, error) {
db := factory.Database.MustGet()
func (r *repository) FindUsers(filter *types.UserFilter) ([]*types.User, error) {
rval := make([]*types.User, 0)
params := make([]interface{}, 0)
sql := "SELECT * FROM users WHERE " + sqlUserScope
@@ -58,32 +59,32 @@ func (r user) Find(ctx context.Context, filter *types.UserFilter) ([]*types.User
sql += " ORDER BY username ASC"
return rval, db.With(ctx).Select(&rval, sql, params...)
return rval, r.db().Select(&rval, sql, params...)
}
func (r user) Create(ctx context.Context, mod *types.User) (*types.User, error) {
func (r *repository) CreateUser(mod *types.User) (*types.User, error) {
mod.ID = factory.Sonyflake.NextID()
mod.CreatedAt = time.Now()
mod.Meta = coalesceJson(mod.Meta, []byte("{}"))
return mod, factory.Database.MustGet().With(ctx).Insert("users", mod)
return mod, r.db().Insert("users", mod)
}
func (r user) Update(ctx context.Context, mod *types.User) (*types.User, error) {
func (r *repository) UpdateUser(mod *types.User) (*types.User, error) {
mod.UpdatedAt = timeNowPtr()
mod.Meta = coalesceJson(mod.Meta, []byte("{}"))
return mod, factory.Database.MustGet().With(ctx).Replace("users", mod)
return mod, r.db().Replace("users", mod)
}
func (r user) Suspend(ctx context.Context, id uint64) error {
return simpleUpdate(ctx, "users", "suspend_at", time.Now(), id)
func (r *repository) SuspendUserByID(id uint64) error {
return simpleUpdate(r.ctx, "users", "suspend_at", time.Now(), id)
}
func (r user) Unsuspend(ctx context.Context, id uint64) error {
return simpleUpdate(ctx, "users", "suspend_at", nil, id)
func (r *repository) UnsuspendUserByID(id uint64) error {
return simpleUpdate(r.ctx, "users", "suspend_at", nil, id)
}
func (r user) Delete(ctx context.Context, id uint64) error {
return simpleDelete(ctx, "users", id)
func (r *repository) DeleteUserByID(id uint64) error {
return simpleDelete(r.ctx, "users", id)
}

View File

@@ -8,41 +8,56 @@ import (
type (
channel struct {
repository channelRepository
rpo channelRepository
//
//sec struct {
// ch channelSecurity
//}
}
channelRepository interface {
FindByID(ctx context.Context, channelID uint64) (*types.Channel, error)
Find(ctx context.Context, filter *types.ChannelFilter) ([]*types.Channel, error)
Create(ctx context.Context, channel *types.Channel) (*types.Channel, error)
Update(ctx context.Context, channel *types.Channel) (*types.Channel, error)
deleter
archiver
repository.Transactionable
repository.Channel
}
//channelSecurity interface {
// CanRead(ctx context.Context, ch *types.Channel) bool
//}
)
func Channel() *channel {
return &channel{repository: repository.Channel()}
var svc = &channel{}
svc.rpo = repository.New()
//svc.sec.ch = ChannelSecurity(svc.rpo)
return svc
}
func (svc channel) FindByID(ctx context.Context, id uint64) (*types.Channel, error) {
// @todo: permission check if current user can read channel
return svc.repository.FindByID(ctx, id)
func (svc channel) FindByID(ctx context.Context, id uint64) (ch *types.Channel, err error) {
ch, err = svc.rpo.FindChannelByID(id)
if err != nil {
return
}
//if !svc.sec.ch.CanRead(ch) {
// return nil, errors.New("Not allowed to access channel")
//}
return
}
func (svc channel) Find(ctx context.Context, filter *types.ChannelFilter) ([]*types.Channel, error) {
// @todo: permission check to return only channels that channel has access to
// @todo: actual searching not just a full select
return svc.repository.Find(ctx, filter)
return svc.rpo.FindChannels(filter)
}
func (svc channel) Create(ctx context.Context, mod *types.Channel) (*types.Channel, error) {
// @todo: topic channelEvent/log entry
// @todo: channel name cmessage/log entry
// @todo: permission check if channel can add channel
return svc.repository.Create(ctx, mod)
return svc.rpo.CreateChannel(mod)
}
func (svc channel) Update(ctx context.Context, mod *types.Channel) (*types.Channel, error) {
@@ -54,26 +69,58 @@ func (svc channel) Update(ctx context.Context, mod *types.Channel) (*types.Chann
// @todo: handle channel movinga
// @todo: handle channel archiving
return svc.repository.Update(ctx, mod)
return svc.rpo.UpdateChannel(mod)
}
func (svc channel) Delete(ctx context.Context, id uint64) error {
// @todo: make history unavailable
// @todo: notify users that channel has been removed (remove from web UI)
// @todo: permissions check if current user can remove channel
return svc.repository.Delete(ctx, id)
return svc.rpo.DeleteChannelByID(id)
}
func (svc channel) Archive(ctx context.Context, id uint64) error {
// @todo: make history unavailable
// @todo: notify users that channel has been removed (remove from web UI)
// @todo: permissions check if current user can remove channel
return svc.repository.Archive(ctx, id)
return svc.rpo.ArchiveChannelByID(id)
}
func (svc channel) Unarchive(ctx context.Context, id uint64) error {
// @todo: make history unavailable
// @todo: notify users that channel has been removed (remove from web UI)
// @todo: permissions check if current user can remove channel
return svc.repository.Unarchive(ctx, id)
return svc.rpo.UnarchiveChannelByID(id)
}
//// @todo temp location, move this somewhere else
//type (
// nativeChannelSec struct {
// rpo struct {
// ch nativeChannelSecChRepo
// }
// }
//
// nativeChannelSecChRepo interface {
// FindMember(ctx context.Context, channelId uint64, userId uint64) (*types.User, error)
// }
//)
//
//func ChannelSecurity(chRpo nativeChannelSecChRepo) channelSecurity {
// var sec = &nativeChannelSec{}
//
// sec.rpo.ch = chRpo
//
// return sec
//}
//
//// Current user can read the channel if he is a member
//func (sec nativeChannelSec) CanRead(ctx context.Context, ch *types.Channel) bool {
// // @todo check if channel is public?
//
// var currentUserID = auth.GetIdentityFromContext(ctx).Identity()
//
// user, err := sec.rpo.FindMember(ch.ID, currentUserID)
//
// return err != nil && user.Valid()
//}

View File

@@ -9,42 +9,19 @@ import (
type (
message struct {
repository struct {
message messageRepository
reaction messageReactionRepository
attachment messageAttachmentRepository
}
rpo messageRepository
}
messageRepository interface {
FindByID(ctx context.Context, messageID uint64) (*types.Message, error)
Find(ctx context.Context, filter *types.MessageFilter) ([]*types.Message, error)
Create(ctx context.Context, message *types.Message) (*types.Message, error)
Update(ctx context.Context, message *types.Message) (*types.Message, error)
deleter
}
messageReactionRepository interface {
FindByID(ctx context.Context, reactionID uint64) (*types.Reaction, error)
Create(ctx context.Context, reaction *types.Reaction) (*types.Reaction, error)
Delete(ctx context.Context, reactionID uint64) error
}
messageAttachmentRepository interface {
FindByID(ctx context.Context, attachmentID uint64) (*types.Attachment, error)
Create(ctx context.Context, attachment *types.Attachment) (*types.Attachment, error)
Delete(ctx context.Context, attachmentID uint64) error
repository.Transactionable
repository.Message
repository.Reaction
repository.Attachment
}
)
func Message() *message {
m := &message{}
m.repository.message = repository.Message()
m.repository.reaction = repository.Reaction()
m.repository.attachment = repository.Attachment()
m := &message{rpo: repository.New()}
return m
}
@@ -56,7 +33,7 @@ func (svc message) Find(ctx context.Context, filter *types.MessageFilter) ([]*ty
_ = currentUserID
_ = filter.ChannelID
return svc.repository.message.Find(ctx, filter)
return svc.rpo.FindMessages(filter)
}
func (svc message) Create(ctx context.Context, mod *types.Message) (*types.Message, error) {
@@ -66,7 +43,7 @@ func (svc message) Create(ctx context.Context, mod *types.Message) (*types.Messa
// @todo verify if current user can access & write to this channel
_ = currentUserID
return svc.repository.message.Create(ctx, mod)
return svc.rpo.CreateMessage(mod)
}
func (svc message) Update(ctx context.Context, mod *types.Message) (*types.Message, error) {
@@ -80,7 +57,7 @@ func (svc message) Update(ctx context.Context, mod *types.Message) (*types.Messa
// @todo verify ownership
return svc.repository.message.Update(ctx, mod)
return svc.rpo.UpdateMessage(mod)
}
func (svc message) Delete(ctx context.Context, id uint64) error {
@@ -94,7 +71,7 @@ func (svc message) Delete(ctx context.Context, id uint64) error {
// @todo verify ownership
return svc.repository.message.Delete(ctx, id)
return svc.rpo.DeleteMessageByID(id)
}
func (svc message) React(ctx context.Context, messageID uint64, reaction string) error {
@@ -113,7 +90,7 @@ func (svc message) React(ctx context.Context, messageID uint64, reaction string)
Reaction: reaction,
}
if _, err := svc.repository.reaction.Create(ctx, r); err != nil {
if _, err := svc.rpo.CreateReaction(r); err != nil {
return err
}
@@ -130,7 +107,7 @@ func (svc message) Unreact(ctx context.Context, messageID uint64, reaction strin
// @todo load reaction and verify ownership
var r *types.Reaction
return svc.repository.reaction.Delete(ctx, r.ID)
return svc.rpo.DeleteReactionByID(r.ID)
}
func (svc message) Pin(ctx context.Context, messageID uint64) error {
@@ -194,5 +171,5 @@ func (svc message) Detach(ctx context.Context, attachmentID uint64) error {
// @todo verify if current user can remove this attachment
return svc.repository.attachment.Delete(ctx, attachmentID)
return svc.rpo.DeleteAttachmentByID(attachmentID)
}

View File

@@ -8,67 +8,61 @@ import (
type (
organisation struct {
repository organisationRepository
rpo organisationRepository
}
organisationRepository interface {
FindByID(ctx context.Context, organisationID uint64) (*types.Organisation, error)
Find(ctx context.Context, filter *types.OrganisationFilter) ([]*types.Organisation, error)
Create(ctx context.Context, organisation *types.Organisation) (*types.Organisation, error)
Update(ctx context.Context, organisation *types.Organisation) (*types.Organisation, error)
deleter
archiver
repository.Transactionable
repository.Organisation
}
)
func Organisation() *organisation {
return &organisation{repository: repository.Organisation()}
return &organisation{rpo: repository.New()}
}
func (svc organisation) FindByID(ctx context.Context, id uint64) (*types.Organisation, error) {
// @todo: permission check if current user can read organisation
return svc.repository.FindByID(ctx, id)
return svc.rpo.FindOrganisationByID(id)
}
func (svc organisation) Find(ctx context.Context, filter *types.OrganisationFilter) ([]*types.Organisation, error) {
// @todo: permission check to return only organisations that organisation has access to
// @todo: actual searching not just a full select
return svc.repository.Find(ctx, filter)
return svc.rpo.FindOrganisations(filter)
}
func (svc organisation) Create(ctx context.Context, mod *types.Organisation) (*types.Organisation, error) {
// @todo: permission check if current user can add/edit organisation
// @todo: make sure archived & deleted entries can not be edited
return svc.repository.Create(ctx, mod)
return svc.rpo.CreateOrganisation(mod)
}
func (svc organisation) Update(ctx context.Context, mod *types.Organisation) (*types.Organisation, error) {
// @todo: permission check if current user can add/edit organisation
// @todo: make sure archived & deleted entries can not be edited
return svc.repository.Update(ctx, mod)
return svc.rpo.UpdateOrganisation(mod)
}
func (svc organisation) Delete(ctx context.Context, id uint64) error {
// @todo: permissions check if current user can remove organisation
// @todo: make history unavailable
// @todo: notify users that organisation has been removed (remove from web UI)
return svc.repository.Delete(ctx, id)
return svc.rpo.DeleteOrganisationByID(id)
}
func (svc organisation) Archive(ctx context.Context, id uint64) error {
// @todo: make history unavailable
// @todo: notify users that organisation has been removed (remove from web UI)
// @todo: permissions check if current user can archive organisation
return svc.repository.Archive(ctx, id)
return svc.rpo.ArchiveOrganisationByID(id)
}
func (svc organisation) Unarchive(ctx context.Context, id uint64) error {
// @todo: make history unavailable
// @todo: notify users that organisation has been removed (remove from web UI)
// @todo: permissions check if current user can unarchive organisation
return svc.repository.Unarchive(ctx, id)
return svc.rpo.UnarchiveOrganisationByID(id)
}

View File

@@ -1,20 +0,0 @@
package service
import (
"context"
)
type (
suspender interface {
Suspend(ctx context.Context, ID uint64) error
Unsuspend(ctx context.Context, ID uint64) error
}
archiver interface {
Archive(ctx context.Context, ID uint64) error
Unarchive(ctx context.Context, ID uint64) error
}
deleter interface {
Delete(ctx context.Context, ID uint64) error
}
)

View File

@@ -8,78 +8,69 @@ import (
type (
team struct {
repository teamRepository
rpo teamRepository
}
teamRepository interface {
FindByID(ctx context.Context, teamID uint64) (*types.Team, error)
Find(ctx context.Context, filter *types.TeamFilter) ([]*types.Team, error)
Create(ctx context.Context, team *types.Team) (*types.Team, error)
Update(ctx context.Context, team *types.Team) (*types.Team, error)
Merge(ctx context.Context, teamID, targetTeamID uint64) error
Move(ctx context.Context, teamID, organisationID uint64) error
deleter
archiver
repository.Transactionable
repository.Team
}
)
func Team() *team {
return &team{repository: repository.Team()}
return &team{rpo: repository.New()}
}
func (svc team) FindByID(ctx context.Context, id uint64) (*types.Team, error) {
// @todo: permission check if current user has access to this team
return svc.repository.FindByID(ctx, id)
return svc.rpo.FindTeamByID(id)
}
func (svc team) Find(ctx context.Context, filter *types.TeamFilter) ([]*types.Team, error) {
// @todo: permission check to return only teams that current user has access to
return svc.repository.Find(ctx, filter)
return svc.rpo.FindTeams(filter)
}
func (svc team) Create(ctx context.Context, mod *types.Team) (*types.Team, error) {
// @todo: permission check if current user can add/edit team
return svc.repository.Create(ctx, mod)
return svc.rpo.CreateTeam(mod)
}
func (svc team) Update(ctx context.Context, mod *types.Team) (*types.Team, error) {
// @todo: permission check if current user can add/edit team
// @todo: make sure archived & deleted entries can not be edited
return svc.repository.Update(ctx, mod)
return svc.rpo.UpdateTeam(mod)
}
func (svc team) Delete(ctx context.Context, id uint64) error {
// @todo: make history unavailable
// @todo: notify users that team has been removed (remove from web UI)
// @todo: permissions check if current user can remove team
return svc.repository.Delete(ctx, id)
return svc.rpo.DeleteTeamByID(id)
}
func (svc team) Archive(ctx context.Context, id uint64) error {
// @todo: make history unavailable
// @todo: notify users that team has been removed (remove from web UI)
// @todo: permissions check if current user can remove team
return svc.repository.Archive(ctx, id)
return svc.rpo.ArchiveTeamByID(id)
}
func (svc team) Unarchive(ctx context.Context, id uint64) error {
// @todo: permissions check if current user can unarchive team
// @todo: make history accessible
// @todo: notify users that team has been unarchived
return svc.repository.Unarchive(ctx, id)
return svc.rpo.UnarchiveTeamByID(id)
}
func (svc team) Merge(ctx context.Context, id, targetTeamID uint64) error {
// @todo: permission check if current user can merge team
return svc.repository.Merge(ctx, id, targetTeamID)
return svc.rpo.MergeTeamByID(id, targetTeamID)
}
func (svc team) Move(ctx context.Context, id, targetOrganisationID uint64) error {
// @todo: permission check if current user can move team to another organisation
return svc.repository.Move(ctx, id, targetOrganisationID)
return svc.rpo.MoveTeamByID(id, targetOrganisationID)
}

View File

@@ -14,28 +14,22 @@ const (
type (
user struct {
repository userRepository
rpo userRepository
}
userRepository interface {
FindByUsername(ctx context.Context, username string) (*types.User, error)
FindByID(ctx context.Context, userID uint64) (*types.User, error)
Find(ctx context.Context, filter *types.UserFilter) ([]*types.User, error)
Create(ctx context.Context, user *types.User) (*types.User, error)
Update(ctx context.Context, user *types.User) (*types.User, error)
deleter
suspender
repository.Transactionable
repository.Contextable
repository.User
}
)
func User() *user {
return &user{repository: repository.User()}
return &user{rpo: repository.New()}
}
func (svc user) ValidateCredentials(ctx context.Context, username, password string) (*types.User, error) {
user, err := svc.repository.FindByUsername(ctx, username)
user, err := svc.rpo.FindUserByUsername(username)
if err != nil {
return nil, err
}
@@ -52,19 +46,26 @@ func (svc user) ValidateCredentials(ctx context.Context, username, password stri
}
func (svc user) FindByID(ctx context.Context, id uint64) (*types.User, error) {
return svc.repository.FindByID(ctx, id)
return svc.rpo.WithCtx(ctx).FindUserByID(id)
}
func (svc user) Find(ctx context.Context, filter *types.UserFilter) ([]*types.User, error) {
return svc.repository.Find(ctx, filter)
return svc.rpo.FindUsers(filter)
}
func (svc user) Create(ctx context.Context, mod *types.User) (*types.User, error) {
return svc.repository.Create(ctx, mod)
func (svc user) Create(ctx context.Context, input *types.User) (new *types.User, err error) {
// no real need for tx here, just presenting the capabilities
return new, svc.rpo.BeginWith(ctx, func(r repository.Interfaces) error {
if new, err = r.CreateUser(input); err != nil {
return err
}
return nil
})
}
func (svc user) Update(ctx context.Context, mod *types.User) (*types.User, error) {
return svc.repository.Update(ctx, mod)
return svc.rpo.UpdateUser(mod)
}
func (svc user) validatePassword(user *types.User, password string) bool {
@@ -89,17 +90,17 @@ func (svc user) canLogin(u *types.User) bool {
func (svc user) Delete(ctx context.Context, id uint64) error {
// @todo: permissions check if current user can delete this user
// @todo: notify users that user has been removed (remove from web UI)
return svc.repository.Delete(ctx, id)
return svc.rpo.DeleteUserByID(id)
}
func (svc user) Suspend(ctx context.Context, id uint64) error {
// @todo: permissions check if current user can suspend this user
// @todo: notify users that user has been supsended (remove from web UI)
return svc.repository.Suspend(ctx, id)
return svc.rpo.SuspendUserByID(id)
}
func (svc user) Unsuspend(ctx context.Context, id uint64) error {
// @todo: permissions check if current user can unsuspend this user
// @todo: notify users that user has been unsuspended
return svc.repository.Unsuspend(ctx, id)
return svc.rpo.UnsuspendUserByID(id)
}