Apply new repo/svc patterns to auth pkg
This commit is contained in:
parent
7a037b2b2b
commit
4788e61c07
@ -1,19 +0,0 @@
|
|||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"github.com/titpetric/factory"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _db *factory.DB
|
|
||||||
|
|
||||||
func DB(ctxs ...context.Context) *factory.DB {
|
|
||||||
if _db == nil {
|
|
||||||
_db = factory.Database.MustGet()
|
|
||||||
}
|
|
||||||
for _, ctx := range ctxs {
|
|
||||||
_db = _db.With(ctx)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
return _db
|
|
||||||
}
|
|
||||||
@ -8,37 +8,32 @@ import (
|
|||||||
type (
|
type (
|
||||||
repository struct {
|
repository struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
dbh *factory.DB
|
||||||
// Get database handle
|
|
||||||
dbh func(ctxs ...context.Context) *factory.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
Repository interface {
|
|
||||||
Context() context.Context
|
|
||||||
DB() *factory.DB
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
// With updates repository and database contexts
|
// DB produces a contextual DB handle
|
||||||
func (r *repository) With(ctx context.Context) *repository {
|
func DB(ctx context.Context) *factory.DB {
|
||||||
res := &repository{
|
return factory.Database.MustGet().With(ctx)
|
||||||
ctx: ctx,
|
|
||||||
dbh: DB,
|
|
||||||
}
|
|
||||||
if r != nil {
|
|
||||||
res.dbh = r.dbh
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// With updates repository and database contexts
|
||||||
|
func (r *repository) With(ctx context.Context, db *factory.DB) *repository {
|
||||||
|
return &repository{
|
||||||
|
ctx: ctx,
|
||||||
|
dbh: db,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context returns current active repository context
|
||||||
func (r *repository) Context() context.Context {
|
func (r *repository) Context() context.Context {
|
||||||
return r.ctx
|
return r.ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return context-aware db handle
|
// db returns context-aware db handle
|
||||||
func (r *repository) db() *factory.DB {
|
func (r *repository) db() *factory.DB {
|
||||||
return r.dbh(r.ctx)
|
if r.dbh != nil {
|
||||||
}
|
return r.dbh
|
||||||
func (r *repository) DB() *factory.DB {
|
}
|
||||||
return r.db()
|
return DB(r.ctx)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/titpetric/factory"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
@ -13,22 +14,20 @@ type (
|
|||||||
}
|
}
|
||||||
|
|
||||||
Settings interface {
|
Settings interface {
|
||||||
Repository
|
With(ctx context.Context, db *factory.DB) Settings
|
||||||
|
|
||||||
With(context.Context) Settings
|
|
||||||
|
|
||||||
Get(name string, value interface{}) (bool, error)
|
Get(name string, value interface{}) (bool, error)
|
||||||
Set(name string, value interface{}) error
|
Set(name string, value interface{}) error
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewSettings(ctx context.Context) Settings {
|
func NewSettings(ctx context.Context, db *factory.DB) Settings {
|
||||||
return (&settings{}).With(ctx)
|
return (&settings{}).With(ctx, db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *settings) With(ctx context.Context) Settings {
|
func (r *settings) With(ctx context.Context, db *factory.DB) Settings {
|
||||||
return &settings{
|
return &settings{
|
||||||
repository: r.repository.With(ctx),
|
repository: r.repository.With(ctx, db),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -8,14 +8,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
user struct {
|
UserRepository interface {
|
||||||
*repository
|
With(ctx context.Context, db *factory.DB) UserRepository
|
||||||
}
|
|
||||||
|
|
||||||
User interface {
|
|
||||||
Repository
|
|
||||||
|
|
||||||
With(context.Context) User
|
|
||||||
|
|
||||||
FindUserByEmail(email string) (*types.User, error)
|
FindUserByEmail(email string) (*types.User, error)
|
||||||
FindUserByUsername(username string) (*types.User, error)
|
FindUserByUsername(username string) (*types.User, error)
|
||||||
@ -28,6 +22,10 @@ type (
|
|||||||
UnsuspendUserByID(id uint64) error
|
UnsuspendUserByID(id uint64) error
|
||||||
DeleteUserByID(id uint64) error
|
DeleteUserByID(id uint64) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user struct {
|
||||||
|
*repository
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -40,14 +38,12 @@ const (
|
|||||||
ErrUserNotFound = repositoryError("UserNotFound")
|
ErrUserNotFound = repositoryError("UserNotFound")
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewUser(ctx context.Context) User {
|
func User(ctx context.Context, db *factory.DB) UserRepository {
|
||||||
return (&user{}).With(ctx)
|
return (&user{}).With(ctx, db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *user) With(ctx context.Context) User {
|
func (r *user) With(ctx context.Context, db *factory.DB) UserRepository {
|
||||||
return &user{
|
return &user{repository: r.repository.With(ctx, db)}
|
||||||
repository: r.repository.With(ctx),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *user) FindUserByUsername(username string) (*types.User, error) {
|
func (r *user) FindUserByUsername(username string) (*types.User, error) {
|
||||||
|
|||||||
@ -163,8 +163,9 @@ func (c *openIdConnect) HandleOAuth2Callback(w http.ResponseWriter, r *http.Requ
|
|||||||
u.Claims(p)
|
u.Claims(p)
|
||||||
|
|
||||||
var user = &types.User{
|
var user = &types.User{
|
||||||
Email: p.Email,
|
SatosaID: p.Sub,
|
||||||
Name: p.Name,
|
Email: p.Email,
|
||||||
|
Name: p.Name,
|
||||||
}
|
}
|
||||||
|
|
||||||
if user, err = c.userService.With(ctx).FindOrCreate(user); err != nil {
|
if user, err = c.userService.With(ctx).FindOrCreate(user); err != nil {
|
||||||
|
|||||||
@ -26,7 +26,7 @@ func MountRoutes(oidcConfig *config.OIDC, jwtAuth jwtEncodeCookieSetter) func(ch
|
|||||||
var userSvc = service.User()
|
var userSvc = service.User()
|
||||||
var ctx = context.Background()
|
var ctx = context.Background()
|
||||||
|
|
||||||
oidc, err := OpenIdConnect(ctx, oidcConfig, userSvc, jwtAuth, repository.NewSettings(ctx))
|
oidc, err := OpenIdConnect(ctx, oidcConfig, userSvc, jwtAuth, repository.NewSettings(ctx, repository.DB(ctx)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print("Could not initialize OIDC:", err.Error())
|
log.Print("Could not initialize OIDC:", err.Error())
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,16 +5,23 @@ import (
|
|||||||
|
|
||||||
"github.com/crusttech/crust/auth/repository"
|
"github.com/crusttech/crust/auth/repository"
|
||||||
"github.com/crusttech/crust/auth/types"
|
"github.com/crusttech/crust/auth/types"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/titpetric/factory"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ErrUserInvalidCredentials = serviceError("UserInvalidCredentials")
|
ErrUserInvalidCredentials = serviceError("UserInvalidCredentials")
|
||||||
ErrUserLocked = serviceError("UserLocked")
|
ErrUserLocked = serviceError("UserLocked")
|
||||||
|
|
||||||
|
uuidLength = 36
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
user struct {
|
user struct {
|
||||||
repository repository.User
|
db *factory.DB
|
||||||
|
ctx context.Context
|
||||||
|
|
||||||
|
user repository.UserRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
UserService interface {
|
UserService interface {
|
||||||
@ -32,19 +39,21 @@ type (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func User() UserService {
|
func User() UserService {
|
||||||
return &user{
|
return (&user{}).With(context.Background())
|
||||||
repository.NewUser(context.Background()),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (svc *user) With(ctx context.Context) UserService {
|
func (svc *user) With(ctx context.Context) UserService {
|
||||||
|
db := repository.DB(ctx)
|
||||||
|
|
||||||
return &user{
|
return &user{
|
||||||
svc.repository.With(ctx),
|
db: db,
|
||||||
|
ctx: ctx,
|
||||||
|
user: repository.User(ctx, db),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (svc *user) ValidateCredentials(username, password string) (*types.User, error) {
|
func (svc *user) ValidateCredentials(username, password string) (*types.User, error) {
|
||||||
user, err := svc.repository.FindUserByUsername(username)
|
user, err := svc.user.FindUserByUsername(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -61,42 +70,47 @@ func (svc *user) ValidateCredentials(username, password string) (*types.User, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (svc *user) FindByID(id uint64) (*types.User, error) {
|
func (svc *user) FindByID(id uint64) (*types.User, error) {
|
||||||
return svc.repository.FindUserByID(id)
|
return svc.user.FindUserByID(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (svc *user) Find(filter *types.UserFilter) ([]*types.User, error) {
|
func (svc *user) Find(filter *types.UserFilter) ([]*types.User, error) {
|
||||||
return svc.repository.FindUsers(filter)
|
return svc.user.FindUsers(filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finds if user with a specific satosa id exists and returns it otherwise it creates a fresh one
|
// Finds if user with a specific satosa id exists and returns it otherwise it creates a fresh one
|
||||||
func (svc *user) FindOrCreate(user *types.User) (out *types.User, err error) {
|
func (svc *user) FindOrCreate(user *types.User) (out *types.User, err error) {
|
||||||
//return out, svc.repository.DB().Transaction(func() error {
|
return out, svc.db.Transaction(func() error {
|
||||||
out, err = svc.repository.FindUserBySatosaID(user.SatosaID)
|
if len(user.SatosaID) != uuidLength {
|
||||||
|
// @todo uuid format check
|
||||||
|
return errors.Errorf("Invalid UUID value (%v) for SATOSA ID", user.SatosaID)
|
||||||
|
}
|
||||||
|
|
||||||
if err == repository.ErrUserNotFound {
|
out, err = svc.user.FindUserBySatosaID(user.SatosaID)
|
||||||
out, err = svc.repository.CreateUser(user)
|
|
||||||
return out, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err == repository.ErrUserNotFound {
|
||||||
// FindUserBySatosaID error
|
out, err = svc.user.CreateUser(user)
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// @todo need to be more selective with fields we update...
|
if err != nil {
|
||||||
out, err = svc.repository.UpdateUser(out)
|
// FindUserBySatosaID error
|
||||||
if err != nil {
|
return err
|
||||||
return nil, err
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return out, nil
|
// @todo need to be more selective with fields we update...
|
||||||
//})
|
out, err = svc.user.UpdateUser(out)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (svc *user) Create(input *types.User) (user *types.User, err error) {
|
func (svc *user) Create(input *types.User) (out *types.User, err error) {
|
||||||
return user, svc.repository.DB().Transaction(func() error {
|
return out, svc.db.Transaction(func() error {
|
||||||
// Encrypt user password
|
// Encrypt user password
|
||||||
if user, err = svc.repository.CreateUser(input); err != nil {
|
if out, err = svc.user.CreateUser(input); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -104,7 +118,7 @@ func (svc *user) Create(input *types.User) (user *types.User, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (svc *user) Update(mod *types.User) (*types.User, error) {
|
func (svc *user) Update(mod *types.User) (*types.User, error) {
|
||||||
return svc.repository.UpdateUser(mod)
|
return svc.user.UpdateUser(mod)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (svc *user) canLogin(u *types.User) bool {
|
func (svc *user) canLogin(u *types.User) bool {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user