upd(crm): refactored db handles in repo
This commit is contained in:
@@ -55,7 +55,6 @@ func (r *content) Create(mod *types.Content) (*types.Content, error) {
|
||||
|
||||
func (r *content) Update(mod *types.Content) (*types.Content, error) {
|
||||
return mod, r.db().Replace("crm_module_content", mod)
|
||||
|
||||
}
|
||||
|
||||
func (r *content) DeleteByID(id uint64) error {
|
||||
|
||||
19
crm/repository/db.go
Normal file
19
crm/repository/db.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/titpetric/factory"
|
||||
)
|
||||
|
||||
var _db *factory.DB
|
||||
|
||||
func DB(ctxs ...context.Context) *factory.DB {
|
||||
if _db == nil {
|
||||
_db = factory.Database.MustGet()
|
||||
}
|
||||
for _, ctx := range ctxs {
|
||||
_db = _db.With(ctx)
|
||||
break
|
||||
}
|
||||
return _db
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/titpetric/factory"
|
||||
)
|
||||
|
||||
@@ -10,34 +9,24 @@ type (
|
||||
repository struct {
|
||||
ctx context.Context
|
||||
|
||||
// Current transaction
|
||||
tx *factory.DB
|
||||
// Get database handle
|
||||
dbh func(ctxs ...context.Context) *factory.DB
|
||||
}
|
||||
)
|
||||
|
||||
// With updates repository and database contexts
|
||||
func (r *repository) With(ctx context.Context) *repository {
|
||||
return &repository{
|
||||
res := &repository{
|
||||
ctx: ctx,
|
||||
tx: r.db().With(r.ctx),
|
||||
dbh: DB,
|
||||
}
|
||||
if r != nil {
|
||||
res.dbh = r.dbh
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (r *repository) Begin() error {
|
||||
return r.db().Begin()
|
||||
}
|
||||
|
||||
func (r *repository) Commit() error {
|
||||
return errors.Wrap(r.db().Commit(), "Can not commit changes")
|
||||
}
|
||||
|
||||
func (r *repository) Rollback() error {
|
||||
return errors.Wrap(r.db().Rollback(), "Can not rollback changes")
|
||||
}
|
||||
|
||||
// Return context-aware db handle
|
||||
func (r *repository) db() *factory.DB {
|
||||
if r.tx == nil {
|
||||
r.tx = factory.Database.MustGet().With(r.ctx)
|
||||
}
|
||||
return r.tx
|
||||
return r.dbh(r.ctx)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user