3
0

Apply new repo/svc patterns to auth pkg

This commit is contained in:
Denis Arh 2018-09-27 14:43:10 +02:00
parent 7a037b2b2b
commit 4788e61c07
7 changed files with 81 additions and 95 deletions

View File

@ -1,19 +0,0 @@
package repository
import (
"context"
"github.com/titpetric/factory"
)
var _db *factory.DB
func DB(ctxs ...context.Context) *factory.DB {
if _db == nil {
_db = factory.Database.MustGet()
}
for _, ctx := range ctxs {
_db = _db.With(ctx)
break
}
return _db
}

View File

@ -8,37 +8,32 @@ import (
type ( type (
repository struct { repository struct {
ctx context.Context ctx context.Context
dbh *factory.DB
// Get database handle
dbh func(ctxs ...context.Context) *factory.DB
}
Repository interface {
Context() context.Context
DB() *factory.DB
} }
) )
// With updates repository and database contexts // DB produces a contextual DB handle
func (r *repository) With(ctx context.Context) *repository { func DB(ctx context.Context) *factory.DB {
res := &repository{ return factory.Database.MustGet().With(ctx)
ctx: ctx,
dbh: DB,
}
if r != nil {
res.dbh = r.dbh
}
return res
} }
// With updates repository and database contexts
func (r *repository) With(ctx context.Context, db *factory.DB) *repository {
return &repository{
ctx: ctx,
dbh: db,
}
}
// Context returns current active repository context
func (r *repository) Context() context.Context { func (r *repository) Context() context.Context {
return r.ctx return r.ctx
} }
// Return context-aware db handle // db returns context-aware db handle
func (r *repository) db() *factory.DB { func (r *repository) db() *factory.DB {
return r.dbh(r.ctx) if r.dbh != nil {
} return r.dbh
func (r *repository) DB() *factory.DB { }
return r.db() return DB(r.ctx)
} }

View File

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/titpetric/factory"
) )
type ( type (
@ -13,22 +14,20 @@ type (
} }
Settings interface { Settings interface {
Repository With(ctx context.Context, db *factory.DB) Settings
With(context.Context) Settings
Get(name string, value interface{}) (bool, error) Get(name string, value interface{}) (bool, error)
Set(name string, value interface{}) error Set(name string, value interface{}) error
} }
) )
func NewSettings(ctx context.Context) Settings { func NewSettings(ctx context.Context, db *factory.DB) Settings {
return (&settings{}).With(ctx) return (&settings{}).With(ctx, db)
} }
func (r *settings) With(ctx context.Context) Settings { func (r *settings) With(ctx context.Context, db *factory.DB) Settings {
return &settings{ return &settings{
repository: r.repository.With(ctx), repository: r.repository.With(ctx, db),
} }
} }

View File

@ -8,14 +8,8 @@ import (
) )
type ( type (
user struct { UserRepository interface {
*repository With(ctx context.Context, db *factory.DB) UserRepository
}
User interface {
Repository
With(context.Context) User
FindUserByEmail(email string) (*types.User, error) FindUserByEmail(email string) (*types.User, error)
FindUserByUsername(username string) (*types.User, error) FindUserByUsername(username string) (*types.User, error)
@ -28,6 +22,10 @@ type (
UnsuspendUserByID(id uint64) error UnsuspendUserByID(id uint64) error
DeleteUserByID(id uint64) error DeleteUserByID(id uint64) error
} }
user struct {
*repository
}
) )
const ( const (
@ -40,14 +38,12 @@ const (
ErrUserNotFound = repositoryError("UserNotFound") ErrUserNotFound = repositoryError("UserNotFound")
) )
func NewUser(ctx context.Context) User { func User(ctx context.Context, db *factory.DB) UserRepository {
return (&user{}).With(ctx) return (&user{}).With(ctx, db)
} }
func (r *user) With(ctx context.Context) User { func (r *user) With(ctx context.Context, db *factory.DB) UserRepository {
return &user{ return &user{repository: r.repository.With(ctx, db)}
repository: r.repository.With(ctx),
}
} }
func (r *user) FindUserByUsername(username string) (*types.User, error) { func (r *user) FindUserByUsername(username string) (*types.User, error) {

View File

@ -163,8 +163,9 @@ func (c *openIdConnect) HandleOAuth2Callback(w http.ResponseWriter, r *http.Requ
u.Claims(p) u.Claims(p)
var user = &types.User{ var user = &types.User{
Email: p.Email, SatosaID: p.Sub,
Name: p.Name, Email: p.Email,
Name: p.Name,
} }
if user, err = c.userService.With(ctx).FindOrCreate(user); err != nil { if user, err = c.userService.With(ctx).FindOrCreate(user); err != nil {

View File

@ -26,7 +26,7 @@ func MountRoutes(oidcConfig *config.OIDC, jwtAuth jwtEncodeCookieSetter) func(ch
var userSvc = service.User() var userSvc = service.User()
var ctx = context.Background() var ctx = context.Background()
oidc, err := OpenIdConnect(ctx, oidcConfig, userSvc, jwtAuth, repository.NewSettings(ctx)) oidc, err := OpenIdConnect(ctx, oidcConfig, userSvc, jwtAuth, repository.NewSettings(ctx, repository.DB(ctx)))
if err != nil { if err != nil {
log.Print("Could not initialize OIDC:", err.Error()) log.Print("Could not initialize OIDC:", err.Error())
} }

View File

@ -5,16 +5,23 @@ import (
"github.com/crusttech/crust/auth/repository" "github.com/crusttech/crust/auth/repository"
"github.com/crusttech/crust/auth/types" "github.com/crusttech/crust/auth/types"
"github.com/pkg/errors"
"github.com/titpetric/factory"
) )
const ( const (
ErrUserInvalidCredentials = serviceError("UserInvalidCredentials") ErrUserInvalidCredentials = serviceError("UserInvalidCredentials")
ErrUserLocked = serviceError("UserLocked") ErrUserLocked = serviceError("UserLocked")
uuidLength = 36
) )
type ( type (
user struct { user struct {
repository repository.User db *factory.DB
ctx context.Context
user repository.UserRepository
} }
UserService interface { UserService interface {
@ -32,19 +39,21 @@ type (
) )
func User() UserService { func User() UserService {
return &user{ return (&user{}).With(context.Background())
repository.NewUser(context.Background()),
}
} }
func (svc *user) With(ctx context.Context) UserService { func (svc *user) With(ctx context.Context) UserService {
db := repository.DB(ctx)
return &user{ return &user{
svc.repository.With(ctx), db: db,
ctx: ctx,
user: repository.User(ctx, db),
} }
} }
func (svc *user) ValidateCredentials(username, password string) (*types.User, error) { func (svc *user) ValidateCredentials(username, password string) (*types.User, error) {
user, err := svc.repository.FindUserByUsername(username) user, err := svc.user.FindUserByUsername(username)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -61,42 +70,47 @@ func (svc *user) ValidateCredentials(username, password string) (*types.User, er
} }
func (svc *user) FindByID(id uint64) (*types.User, error) { func (svc *user) FindByID(id uint64) (*types.User, error) {
return svc.repository.FindUserByID(id) return svc.user.FindUserByID(id)
} }
func (svc *user) Find(filter *types.UserFilter) ([]*types.User, error) { func (svc *user) Find(filter *types.UserFilter) ([]*types.User, error) {
return svc.repository.FindUsers(filter) return svc.user.FindUsers(filter)
} }
// Finds if user with a specific satosa id exists and returns it otherwise it creates a fresh one // Finds if user with a specific satosa id exists and returns it otherwise it creates a fresh one
func (svc *user) FindOrCreate(user *types.User) (out *types.User, err error) { func (svc *user) FindOrCreate(user *types.User) (out *types.User, err error) {
//return out, svc.repository.DB().Transaction(func() error { return out, svc.db.Transaction(func() error {
out, err = svc.repository.FindUserBySatosaID(user.SatosaID) if len(user.SatosaID) != uuidLength {
// @todo uuid format check
return errors.Errorf("Invalid UUID value (%v) for SATOSA ID", user.SatosaID)
}
if err == repository.ErrUserNotFound { out, err = svc.user.FindUserBySatosaID(user.SatosaID)
out, err = svc.repository.CreateUser(user)
return out, err
}
if err != nil { if err == repository.ErrUserNotFound {
// FindUserBySatosaID error out, err = svc.user.CreateUser(user)
return nil, err return err
} }
// @todo need to be more selective with fields we update... if err != nil {
out, err = svc.repository.UpdateUser(out) // FindUserBySatosaID error
if err != nil { return err
return nil, err }
}
return out, nil // @todo need to be more selective with fields we update...
//}) out, err = svc.user.UpdateUser(out)
if err != nil {
return err
}
return nil
})
} }
func (svc *user) Create(input *types.User) (user *types.User, err error) { func (svc *user) Create(input *types.User) (out *types.User, err error) {
return user, svc.repository.DB().Transaction(func() error { return out, svc.db.Transaction(func() error {
// Encrypt user password // Encrypt user password
if user, err = svc.repository.CreateUser(input); err != nil { if out, err = svc.user.CreateUser(input); err != nil {
return err return err
} }
return nil return nil
@ -104,7 +118,7 @@ func (svc *user) Create(input *types.User) (user *types.User, err error) {
} }
func (svc *user) Update(mod *types.User) (*types.User, error) { func (svc *user) Update(mod *types.User) (*types.User, error) {
return svc.repository.UpdateUser(mod) return svc.user.UpdateUser(mod)
} }
func (svc *user) canLogin(u *types.User) bool { func (svc *user) canLogin(u *types.User) bool {