diff --git a/crm/repository/content.go b/crm/repository/content.go index e6759d5a8..5a357270e 100644 --- a/crm/repository/content.go +++ b/crm/repository/content.go @@ -55,7 +55,6 @@ func (r *content) Create(mod *types.Content) (*types.Content, error) { func (r *content) Update(mod *types.Content) (*types.Content, error) { return mod, r.db().Replace("crm_module_content", mod) - } func (r *content) DeleteByID(id uint64) error { diff --git a/crm/repository/db.go b/crm/repository/db.go new file mode 100644 index 000000000..828d80514 --- /dev/null +++ b/crm/repository/db.go @@ -0,0 +1,19 @@ +package repository + +import ( + "context" + "github.com/titpetric/factory" +) + +var _db *factory.DB + +func DB(ctxs ...context.Context) *factory.DB { + if _db == nil { + _db = factory.Database.MustGet() + } + for _, ctx := range ctxs { + _db = _db.With(ctx) + break + } + return _db +} diff --git a/crm/repository/repository.go b/crm/repository/repository.go index c8c3827fd..682fd6a10 100644 --- a/crm/repository/repository.go +++ b/crm/repository/repository.go @@ -2,7 +2,6 @@ package repository import ( "context" - "github.com/pkg/errors" "github.com/titpetric/factory" ) @@ -10,34 +9,24 @@ type ( repository struct { ctx context.Context - // Current transaction - tx *factory.DB + // Get database handle + dbh func(ctxs ...context.Context) *factory.DB } ) // With updates repository and database contexts func (r *repository) With(ctx context.Context) *repository { - return &repository{ + res := &repository{ ctx: ctx, - tx: r.db().With(r.ctx), + dbh: DB, } + if r != nil { + res.dbh = r.dbh + } + return res } -func (r *repository) Begin() error { - return r.db().Begin() -} - -func (r *repository) Commit() error { - return errors.Wrap(r.db().Commit(), "Can not commit changes") -} - -func (r *repository) Rollback() error { - return errors.Wrap(r.db().Rollback(), "Can not rollback changes") -} - +// Return context-aware db handle func (r *repository) db() *factory.DB { - if r.tx == nil { - r.tx = factory.Database.MustGet().With(r.ctx) - } - return r.tx + return r.dbh(r.ctx) }