3
0

Fix repository tests & transaction handling

This commit is contained in:
Denis Arh 2018-07-30 11:48:25 +02:00
parent a6e86066f4
commit 8a2423673b
5 changed files with 15 additions and 25 deletions

View File

@ -1,7 +1,6 @@
package repository
import (
"context"
"github.com/crusttech/crust/sam/types"
"testing"
)
@ -14,15 +13,14 @@ func TestAttachment(t *testing.T) {
return
}
rpo := Attachment()
ctx := context.Background()
rpo := New()
att := &types.Attachment{}
var aa []*types.Attachment
att.ChannelID = 1
att, err = rpo.CreateAttachment(ctx, att)
att, err = rpo.CreateAttachment(att)
must(t, err)
if att.ChannelID != 1 {
t.Fatal("Changes were not stored")
@ -30,23 +28,23 @@ func TestAttachment(t *testing.T) {
att.ChannelID = 2
att, err = rpo.UpdateAttachment(ctx, att)
att, err = rpo.UpdateAttachment(att)
must(t, err)
if att.ChannelID != 2 {
t.Fatal("Changes were not stored")
}
att, err = rpo.FindAttachmentByID(ctx, att.ID)
att, err = rpo.FindAttachmentByID(att.ID)
must(t, err)
if att.ChannelID != 2 {
t.Fatal("Changes were not stored")
}
aa, err = rpo.FindAttachmentByRange(ctx, 2, 0, att.ID)
aa, err = rpo.FindAttachmentByRange(2, 0, att.ID)
must(t, err)
if len(aa) == 0 {
t.Fatal("No results found")
}
must(t, rpo.DeleteAttachmentByID(ctx, att.ID))
must(t, rpo.DeleteAttachmentByID(att.ID))
}

View File

@ -1,7 +1,6 @@
package repository
import (
"context"
"github.com/crusttech/crust/sam/types"
"testing"
)
@ -15,7 +14,6 @@ func TestChannel(t *testing.T) {
}
rpo := New()
ctx := context.Background()
chn := &types.Channel{}
var name1, name2 = "Test channel v1", "Test channel v2"

View File

@ -1,7 +1,6 @@
package repository
import (
"context"
"github.com/crusttech/crust/sam/types"
"testing"
)
@ -14,8 +13,7 @@ func TestOrganisation(t *testing.T) {
return
}
rpo := Organisation()
ctx := context.Background()
rpo := New()
org := &types.Organisation{}
var name1, name2 = "Test organisation v1", "Test organisation v2"

View File

@ -49,12 +49,8 @@ func (r *repository) WithCtx(ctx context.Context) Interfaces {
}
func (r *repository) BeginWith(ctx context.Context, callback BeginCallback) error {
tx := r.tx
if tx == nil {
tx = factory.Database.MustGet().With(ctx)
}
txr := &repository{ctx: ctx, tx: tx}
txr := &repository{ctx: ctx}
if err := txr.Begin(); err != nil {
return err
@ -72,20 +68,21 @@ func (r *repository) BeginWith(ctx context.Context, callback BeginCallback) erro
}
func (r *repository) Begin() error {
// @todo implementation
return r.tx.Begin()
return r.db().Begin()
}
func (r *repository) Commit() error {
// @todo implementation
return r.tx.Commit()
return r.db().Commit()
}
func (r *repository) Rollback() error {
// @todo implementation
return r.tx.Rollback()
return r.db().Rollback()
}
func (r *repository) db() *factory.DB {
if r.tx == nil {
r.tx = factory.Database.MustGet().With(r.ctx)
}
return r.tx.With(r.ctx)
}

View File

@ -66,7 +66,6 @@ func (r *repository) CreateUser(mod *types.User) (*types.User, error) {
mod.ID = factory.Sonyflake.NextID()
mod.CreatedAt = time.Now()
mod.Meta = coalesceJson(mod.Meta, []byte("{}"))
return mod, r.db().Insert("users", mod)
}