Fix repository tests & transaction handling
This commit is contained in:
@@ -1,7 +1,6 @@
|
|||||||
package repository
|
package repository
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"github.com/crusttech/crust/sam/types"
|
"github.com/crusttech/crust/sam/types"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@@ -14,15 +13,14 @@ func TestAttachment(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rpo := Attachment()
|
rpo := New()
|
||||||
ctx := context.Background()
|
|
||||||
att := &types.Attachment{}
|
att := &types.Attachment{}
|
||||||
|
|
||||||
var aa []*types.Attachment
|
var aa []*types.Attachment
|
||||||
|
|
||||||
att.ChannelID = 1
|
att.ChannelID = 1
|
||||||
|
|
||||||
att, err = rpo.CreateAttachment(ctx, att)
|
att, err = rpo.CreateAttachment(att)
|
||||||
must(t, err)
|
must(t, err)
|
||||||
if att.ChannelID != 1 {
|
if att.ChannelID != 1 {
|
||||||
t.Fatal("Changes were not stored")
|
t.Fatal("Changes were not stored")
|
||||||
@@ -30,23 +28,23 @@ func TestAttachment(t *testing.T) {
|
|||||||
|
|
||||||
att.ChannelID = 2
|
att.ChannelID = 2
|
||||||
|
|
||||||
att, err = rpo.UpdateAttachment(ctx, att)
|
att, err = rpo.UpdateAttachment(att)
|
||||||
must(t, err)
|
must(t, err)
|
||||||
if att.ChannelID != 2 {
|
if att.ChannelID != 2 {
|
||||||
t.Fatal("Changes were not stored")
|
t.Fatal("Changes were not stored")
|
||||||
}
|
}
|
||||||
|
|
||||||
att, err = rpo.FindAttachmentByID(ctx, att.ID)
|
att, err = rpo.FindAttachmentByID(att.ID)
|
||||||
must(t, err)
|
must(t, err)
|
||||||
if att.ChannelID != 2 {
|
if att.ChannelID != 2 {
|
||||||
t.Fatal("Changes were not stored")
|
t.Fatal("Changes were not stored")
|
||||||
}
|
}
|
||||||
|
|
||||||
aa, err = rpo.FindAttachmentByRange(ctx, 2, 0, att.ID)
|
aa, err = rpo.FindAttachmentByRange(2, 0, att.ID)
|
||||||
must(t, err)
|
must(t, err)
|
||||||
if len(aa) == 0 {
|
if len(aa) == 0 {
|
||||||
t.Fatal("No results found")
|
t.Fatal("No results found")
|
||||||
}
|
}
|
||||||
|
|
||||||
must(t, rpo.DeleteAttachmentByID(ctx, att.ID))
|
must(t, rpo.DeleteAttachmentByID(att.ID))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package repository
|
package repository
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"github.com/crusttech/crust/sam/types"
|
"github.com/crusttech/crust/sam/types"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@@ -15,7 +14,6 @@ func TestChannel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
rpo := New()
|
rpo := New()
|
||||||
ctx := context.Background()
|
|
||||||
chn := &types.Channel{}
|
chn := &types.Channel{}
|
||||||
|
|
||||||
var name1, name2 = "Test channel v1", "Test channel v2"
|
var name1, name2 = "Test channel v1", "Test channel v2"
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package repository
|
package repository
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"github.com/crusttech/crust/sam/types"
|
"github.com/crusttech/crust/sam/types"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@@ -14,8 +13,7 @@ func TestOrganisation(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rpo := Organisation()
|
rpo := New()
|
||||||
ctx := context.Background()
|
|
||||||
org := &types.Organisation{}
|
org := &types.Organisation{}
|
||||||
|
|
||||||
var name1, name2 = "Test organisation v1", "Test organisation v2"
|
var name1, name2 = "Test organisation v1", "Test organisation v2"
|
||||||
|
|||||||
@@ -49,12 +49,8 @@ func (r *repository) WithCtx(ctx context.Context) Interfaces {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *repository) BeginWith(ctx context.Context, callback BeginCallback) error {
|
func (r *repository) BeginWith(ctx context.Context, callback BeginCallback) error {
|
||||||
tx := r.tx
|
|
||||||
if tx == nil {
|
|
||||||
tx = factory.Database.MustGet().With(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
txr := &repository{ctx: ctx, tx: tx}
|
txr := &repository{ctx: ctx}
|
||||||
|
|
||||||
if err := txr.Begin(); err != nil {
|
if err := txr.Begin(); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -72,20 +68,21 @@ func (r *repository) BeginWith(ctx context.Context, callback BeginCallback) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *repository) Begin() error {
|
func (r *repository) Begin() error {
|
||||||
// @todo implementation
|
return r.db().Begin()
|
||||||
return r.tx.Begin()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *repository) Commit() error {
|
func (r *repository) Commit() error {
|
||||||
// @todo implementation
|
return r.db().Commit()
|
||||||
return r.tx.Commit()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *repository) Rollback() error {
|
func (r *repository) Rollback() error {
|
||||||
// @todo implementation
|
return r.db().Rollback()
|
||||||
return r.tx.Rollback()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *repository) db() *factory.DB {
|
func (r *repository) db() *factory.DB {
|
||||||
|
if r.tx == nil {
|
||||||
|
r.tx = factory.Database.MustGet().With(r.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
return r.tx.With(r.ctx)
|
return r.tx.With(r.ctx)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -66,7 +66,6 @@ func (r *repository) CreateUser(mod *types.User) (*types.User, error) {
|
|||||||
mod.ID = factory.Sonyflake.NextID()
|
mod.ID = factory.Sonyflake.NextID()
|
||||||
mod.CreatedAt = time.Now()
|
mod.CreatedAt = time.Now()
|
||||||
mod.Meta = coalesceJson(mod.Meta, []byte("{}"))
|
mod.Meta = coalesceJson(mod.Meta, []byte("{}"))
|
||||||
|
|
||||||
return mod, r.db().Insert("users", mod)
|
return mod, r.db().Insert("users", mod)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user