3
0

Fix repository tests & transaction handling

This commit is contained in:
Denis Arh
2018-07-30 11:48:25 +02:00
parent a6e86066f4
commit 8a2423673b
5 changed files with 15 additions and 25 deletions

View File

@@ -1,7 +1,6 @@
package repository package repository
import ( import (
"context"
"github.com/crusttech/crust/sam/types" "github.com/crusttech/crust/sam/types"
"testing" "testing"
) )
@@ -14,15 +13,14 @@ func TestAttachment(t *testing.T) {
return return
} }
rpo := Attachment() rpo := New()
ctx := context.Background()
att := &types.Attachment{} att := &types.Attachment{}
var aa []*types.Attachment var aa []*types.Attachment
att.ChannelID = 1 att.ChannelID = 1
att, err = rpo.CreateAttachment(ctx, att) att, err = rpo.CreateAttachment(att)
must(t, err) must(t, err)
if att.ChannelID != 1 { if att.ChannelID != 1 {
t.Fatal("Changes were not stored") t.Fatal("Changes were not stored")
@@ -30,23 +28,23 @@ func TestAttachment(t *testing.T) {
att.ChannelID = 2 att.ChannelID = 2
att, err = rpo.UpdateAttachment(ctx, att) att, err = rpo.UpdateAttachment(att)
must(t, err) must(t, err)
if att.ChannelID != 2 { if att.ChannelID != 2 {
t.Fatal("Changes were not stored") t.Fatal("Changes were not stored")
} }
att, err = rpo.FindAttachmentByID(ctx, att.ID) att, err = rpo.FindAttachmentByID(att.ID)
must(t, err) must(t, err)
if att.ChannelID != 2 { if att.ChannelID != 2 {
t.Fatal("Changes were not stored") t.Fatal("Changes were not stored")
} }
aa, err = rpo.FindAttachmentByRange(ctx, 2, 0, att.ID) aa, err = rpo.FindAttachmentByRange(2, 0, att.ID)
must(t, err) must(t, err)
if len(aa) == 0 { if len(aa) == 0 {
t.Fatal("No results found") t.Fatal("No results found")
} }
must(t, rpo.DeleteAttachmentByID(ctx, att.ID)) must(t, rpo.DeleteAttachmentByID(att.ID))
} }

View File

@@ -1,7 +1,6 @@
package repository package repository
import ( import (
"context"
"github.com/crusttech/crust/sam/types" "github.com/crusttech/crust/sam/types"
"testing" "testing"
) )
@@ -15,7 +14,6 @@ func TestChannel(t *testing.T) {
} }
rpo := New() rpo := New()
ctx := context.Background()
chn := &types.Channel{} chn := &types.Channel{}
var name1, name2 = "Test channel v1", "Test channel v2" var name1, name2 = "Test channel v1", "Test channel v2"

View File

@@ -1,7 +1,6 @@
package repository package repository
import ( import (
"context"
"github.com/crusttech/crust/sam/types" "github.com/crusttech/crust/sam/types"
"testing" "testing"
) )
@@ -14,8 +13,7 @@ func TestOrganisation(t *testing.T) {
return return
} }
rpo := Organisation() rpo := New()
ctx := context.Background()
org := &types.Organisation{} org := &types.Organisation{}
var name1, name2 = "Test organisation v1", "Test organisation v2" var name1, name2 = "Test organisation v1", "Test organisation v2"

View File

@@ -49,12 +49,8 @@ func (r *repository) WithCtx(ctx context.Context) Interfaces {
} }
func (r *repository) BeginWith(ctx context.Context, callback BeginCallback) error { func (r *repository) BeginWith(ctx context.Context, callback BeginCallback) error {
tx := r.tx
if tx == nil {
tx = factory.Database.MustGet().With(ctx)
}
txr := &repository{ctx: ctx, tx: tx} txr := &repository{ctx: ctx}
if err := txr.Begin(); err != nil { if err := txr.Begin(); err != nil {
return err return err
@@ -72,20 +68,21 @@ func (r *repository) BeginWith(ctx context.Context, callback BeginCallback) erro
} }
func (r *repository) Begin() error { func (r *repository) Begin() error {
// @todo implementation return r.db().Begin()
return r.tx.Begin()
} }
func (r *repository) Commit() error { func (r *repository) Commit() error {
// @todo implementation return r.db().Commit()
return r.tx.Commit()
} }
func (r *repository) Rollback() error { func (r *repository) Rollback() error {
// @todo implementation return r.db().Rollback()
return r.tx.Rollback()
} }
func (r *repository) db() *factory.DB { func (r *repository) db() *factory.DB {
if r.tx == nil {
r.tx = factory.Database.MustGet().With(r.ctx)
}
return r.tx.With(r.ctx) return r.tx.With(r.ctx)
} }

View File

@@ -66,7 +66,6 @@ func (r *repository) CreateUser(mod *types.User) (*types.User, error) {
mod.ID = factory.Sonyflake.NextID() mod.ID = factory.Sonyflake.NextID()
mod.CreatedAt = time.Now() mod.CreatedAt = time.Now()
mod.Meta = coalesceJson(mod.Meta, []byte("{}")) mod.Meta = coalesceJson(mod.Meta, []byte("{}"))
return mod, r.db().Insert("users", mod) return mod, r.db().Insert("users", mod)
} }