From 8a2423673bef1ea10e27d5cb4159ba8feee5e88e Mon Sep 17 00:00:00 2001 From: Denis Arh Date: Mon, 30 Jul 2018 11:48:25 +0200 Subject: [PATCH] Fix repository tests & transaction handling --- sam/repository/attachment_test.go | 14 ++++++-------- sam/repository/channel_test.go | 2 -- sam/repository/organisation_test.go | 4 +--- sam/repository/repository.go | 19 ++++++++----------- sam/repository/user.go | 1 - 5 files changed, 15 insertions(+), 25 deletions(-) diff --git a/sam/repository/attachment_test.go b/sam/repository/attachment_test.go index 4fbc75201..bc7246762 100644 --- a/sam/repository/attachment_test.go +++ b/sam/repository/attachment_test.go @@ -1,7 +1,6 @@ package repository import ( - "context" "github.com/crusttech/crust/sam/types" "testing" ) @@ -14,15 +13,14 @@ func TestAttachment(t *testing.T) { return } - rpo := Attachment() - ctx := context.Background() + rpo := New() att := &types.Attachment{} var aa []*types.Attachment att.ChannelID = 1 - att, err = rpo.CreateAttachment(ctx, att) + att, err = rpo.CreateAttachment(att) must(t, err) if att.ChannelID != 1 { t.Fatal("Changes were not stored") @@ -30,23 +28,23 @@ func TestAttachment(t *testing.T) { att.ChannelID = 2 - att, err = rpo.UpdateAttachment(ctx, att) + att, err = rpo.UpdateAttachment(att) must(t, err) if att.ChannelID != 2 { t.Fatal("Changes were not stored") } - att, err = rpo.FindAttachmentByID(ctx, att.ID) + att, err = rpo.FindAttachmentByID(att.ID) must(t, err) if att.ChannelID != 2 { t.Fatal("Changes were not stored") } - aa, err = rpo.FindAttachmentByRange(ctx, 2, 0, att.ID) + aa, err = rpo.FindAttachmentByRange(2, 0, att.ID) must(t, err) if len(aa) == 0 { t.Fatal("No results found") } - must(t, rpo.DeleteAttachmentByID(ctx, att.ID)) + must(t, rpo.DeleteAttachmentByID(att.ID)) } diff --git a/sam/repository/channel_test.go b/sam/repository/channel_test.go index 7d2037642..19c1d1040 100644 --- a/sam/repository/channel_test.go +++ b/sam/repository/channel_test.go @@ -1,7 +1,6 @@ package repository import ( - "context" "github.com/crusttech/crust/sam/types" "testing" ) @@ -15,7 +14,6 @@ func TestChannel(t *testing.T) { } rpo := New() - ctx := context.Background() chn := &types.Channel{} var name1, name2 = "Test channel v1", "Test channel v2" diff --git a/sam/repository/organisation_test.go b/sam/repository/organisation_test.go index 9d5a0626f..47cf2aafa 100644 --- a/sam/repository/organisation_test.go +++ b/sam/repository/organisation_test.go @@ -1,7 +1,6 @@ package repository import ( - "context" "github.com/crusttech/crust/sam/types" "testing" ) @@ -14,8 +13,7 @@ func TestOrganisation(t *testing.T) { return } - rpo := Organisation() - ctx := context.Background() + rpo := New() org := &types.Organisation{} var name1, name2 = "Test organisation v1", "Test organisation v2" diff --git a/sam/repository/repository.go b/sam/repository/repository.go index b546d3521..caf20f533 100644 --- a/sam/repository/repository.go +++ b/sam/repository/repository.go @@ -49,12 +49,8 @@ func (r *repository) WithCtx(ctx context.Context) Interfaces { } func (r *repository) BeginWith(ctx context.Context, callback BeginCallback) error { - tx := r.tx - if tx == nil { - tx = factory.Database.MustGet().With(ctx) - } - txr := &repository{ctx: ctx, tx: tx} + txr := &repository{ctx: ctx} if err := txr.Begin(); err != nil { return err @@ -72,20 +68,21 @@ func (r *repository) BeginWith(ctx context.Context, callback BeginCallback) erro } func (r *repository) Begin() error { - // @todo implementation - return r.tx.Begin() + return r.db().Begin() } func (r *repository) Commit() error { - // @todo implementation - return r.tx.Commit() + return r.db().Commit() } func (r *repository) Rollback() error { - // @todo implementation - return r.tx.Rollback() + return r.db().Rollback() } func (r *repository) db() *factory.DB { + if r.tx == nil { + r.tx = factory.Database.MustGet().With(r.ctx) + } + return r.tx.With(r.ctx) } diff --git a/sam/repository/user.go b/sam/repository/user.go index c3d0de67d..dbf716a90 100644 --- a/sam/repository/user.go +++ b/sam/repository/user.go @@ -66,7 +66,6 @@ func (r *repository) CreateUser(mod *types.User) (*types.User, error) { mod.ID = factory.Sonyflake.NextID() mod.CreatedAt = time.Now() mod.Meta = coalesceJson(mod.Meta, []byte("{}")) - return mod, r.db().Insert("users", mod) }