From 980b6d581cc4e5a36fdf35430ad6d7dcffcee6c9 Mon Sep 17 00:00:00 2001 From: Denis Arh Date: Sat, 27 Apr 2019 13:14:03 +0200 Subject: [PATCH] Refactor JWT encoder/handler Handler is no longer passed as argument into routes etc but initialized in the Init() and stored into auth.DefaultJwtHandler. --- crm/routes.go | 11 +++++++---- crm/start.go | 27 +++++++++++++-------------- internal/auth/jwt.go | 4 ++++ messaging/routes.go | 11 +++++++---- messaging/start.go | 23 +++++++++++------------ system/rest/auth.go | 4 ++-- system/rest/auth_external.go | 4 ++-- system/rest/auth_internal.go | 4 ++-- system/rest/router.go | 8 ++++---- system/routes.go | 14 +++++++++----- system/routes_test.go | 2 +- system/start.go | 28 +++++++++++++++------------- 12 files changed, 77 insertions(+), 63 deletions(-) diff --git a/crm/routes.go b/crm/routes.go index 763d2c9f4..4d5e30f6d 100644 --- a/crm/routes.go +++ b/crm/routes.go @@ -12,19 +12,22 @@ import ( "github.com/crusttech/crust/internal/routes" ) -func Routes(ctx context.Context, th auth.TokenHandler) *chi.Mux { +func Routes(ctx context.Context) *chi.Mux { r := chi.NewRouter() middleware.Mount(ctx, r, flags.http) - MountRoutes(ctx, r, th) + MountRoutes(ctx, r) routes.Print(r) middleware.MountSystemRoutes(ctx, r, flags.http) return r } -func MountRoutes(ctx context.Context, r chi.Router, th auth.TokenHandler) { +func MountRoutes(ctx context.Context, r chi.Router) { // Only protect application routes with JWT r.Group(func(r chi.Router) { - r.Use(th.Verifier(), th.Authenticator()) + r.Use( + auth.DefaultJwtHandler.Verifier(), + auth.DefaultJwtHandler.Authenticator(), + ) mountRoutes(r, flags.http, rest.MountRoutes()) }) } diff --git a/crm/start.go b/crm/start.go index 7d2d4fa57..decbce70f 100644 --- a/crm/start.go +++ b/crm/start.go @@ -18,16 +18,16 @@ import ( "github.com/crusttech/crust/internal/metrics" ) -func Init(ctx context.Context) error { +func Init(ctx context.Context) (err error) { // validate configuration - if err := flags.Validate(); err != nil { - return err + if err = flags.Validate(); err != nil { + return } mail.SetupDialer(flags.smtp) - if err := InitDatabase(ctx); err != nil { - return err + if err = InitDatabase(ctx); err != nil { + return } // configure resputil options @@ -39,6 +39,13 @@ func Init(ctx context.Context) error { }, }) + // Use JWT secret for hmac signer for now + auth.DefaultSigner = auth.HmacSigner(flags.jwt.Secret) + auth.DefaultJwtHandler, err = auth.JWT(flags.jwt.Secret, flags.jwt.Expiry) + if err != nil { + return err + } + // Don't change this to init(), it needs Database return service.Init() } @@ -69,15 +76,7 @@ func StartRestAPI(ctx context.Context) error { go metrics.NewMonitor(flags.monitor.Interval) } - // Use JWT secret for hmac signer for now - auth.DefaultSigner = auth.HmacSigner(flags.jwt.Secret) - - jwtAuth, err := auth.JWT(flags.jwt.Secret, flags.jwt.Expiry) - if err != nil { - return errors.Wrap(err, "Error creating JWT Auth") - } - - go http.Serve(listener, Routes(ctx, jwtAuth)) + go http.Serve(listener, Routes(ctx)) <-ctx.Done() return nil diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index bc1d75cda..dfb54f6e0 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -22,6 +22,10 @@ type ( } ) +var ( + DefaultJwtHandler TokenHandler +) + func JWT(secret string, expiry int64) (jwt *token, err error) { if len(secret) == 0 { return nil, errors.New("JWT secret missing") diff --git a/messaging/routes.go b/messaging/routes.go index e6b6e2176..3381b851c 100644 --- a/messaging/routes.go +++ b/messaging/routes.go @@ -13,19 +13,22 @@ import ( "github.com/crusttech/crust/messaging/websocket" ) -func Routes(ctx context.Context, th auth.TokenHandler) *chi.Mux { +func Routes(ctx context.Context) *chi.Mux { r := chi.NewRouter() middleware.Mount(ctx, r, flags.http) - MountRoutes(ctx, r, th) + MountRoutes(ctx, r) routes.Print(r) middleware.MountSystemRoutes(ctx, r, flags.http) return r } -func MountRoutes(ctx context.Context, r chi.Router, th auth.TokenHandler) { +func MountRoutes(ctx context.Context, r chi.Router) { // Only protect application routes with JWT r.Group(func(r chi.Router) { - r.Use(th.Verifier(), th.Authenticator()) + r.Use( + auth.DefaultJwtHandler.Verifier(), + auth.DefaultJwtHandler.Authenticator(), + ) mountRoutes(r, flags.http, rest.MountRoutes(), websocket.MountRoutes(ctx, flags.repository)) }) } diff --git a/messaging/start.go b/messaging/start.go index 5fcd94a05..f772e3405 100644 --- a/messaging/start.go +++ b/messaging/start.go @@ -18,15 +18,15 @@ import ( "github.com/crusttech/crust/messaging/internal/service" ) -func Init(ctx context.Context) error { +func Init(ctx context.Context) (err error) { // validate configuration - if err := flags.Validate(); err != nil { + if err = flags.Validate(); err != nil { return err } mail.SetupDialer(flags.smtp) - if err := InitDatabase(ctx); err != nil { + if err = InitDatabase(ctx); err != nil { return err } @@ -39,6 +39,13 @@ func Init(ctx context.Context) error { }, }) + // Use JWT secret for hmac signer for now + auth.DefaultSigner = auth.HmacSigner(flags.jwt.Secret) + auth.DefaultJwtHandler, err = auth.JWT(flags.jwt.Secret, flags.jwt.Expiry) + if err != nil { + return + } + // Don't change this, it needs Database service.Init() @@ -71,15 +78,7 @@ func StartRestAPI(ctx context.Context) error { go metrics.NewMonitor(flags.monitor.Interval) } - // Use JWT secret for hmac signer for now - auth.DefaultSigner = auth.HmacSigner(flags.jwt.Secret) - - jwtAuth, err := auth.JWT(flags.jwt.Secret, flags.jwt.Expiry) - if err != nil { - return errors.Wrap(err, "Error creating JWT Auth") - } - - go http.Serve(listener, Routes(ctx, jwtAuth)) + go http.Serve(listener, Routes(ctx)) <-ctx.Done() return nil diff --git a/system/rest/auth.go b/system/rest/auth.go index 328e9085e..19b044cb4 100644 --- a/system/rest/auth.go +++ b/system/rest/auth.go @@ -37,9 +37,9 @@ type ( } ) -func (Auth) New(tenc auth.TokenEncoder) *Auth { +func (Auth) New() *Auth { return &Auth{ - jwt: tenc, + jwt: auth.DefaultJwtHandler, authSettings: service.DefaultAuthSettings, authSvc: service.DefaultAuth, } diff --git a/system/rest/auth_external.go b/system/rest/auth_external.go index 504be3d18..f4aa528d4 100644 --- a/system/rest/auth_external.go +++ b/system/rest/auth_external.go @@ -29,10 +29,10 @@ const ( externalAuthBaseUrl = "/auth/external" ) -func NewSocial(jwtEncoder auth.TokenEncoder) *ExternalAuth { +func NewSocial() *ExternalAuth { return &ExternalAuth{ auth: service.DefaultAuth, - jwtEncoder: jwtEncoder, + jwtEncoder: auth.DefaultJwtHandler, } } diff --git a/system/rest/auth_internal.go b/system/rest/auth_internal.go index db95d72d6..4324b01d9 100644 --- a/system/rest/auth_internal.go +++ b/system/rest/auth_internal.go @@ -30,9 +30,9 @@ type ( } ) -func (AuthInternal) New(te auth.TokenEncoder) *AuthInternal { +func (AuthInternal) New() *AuthInternal { return &AuthInternal{ - tokenEncoder: te, + tokenEncoder: auth.DefaultJwtHandler, authSvc: service.DefaultAuth, } } diff --git a/system/rest/router.go b/system/rest/router.go index 4f380445f..1ab5915fc 100644 --- a/system/rest/router.go +++ b/system/rest/router.go @@ -7,15 +7,15 @@ import ( "github.com/crusttech/crust/system/rest/handlers" ) -func MountRoutes(jwtEncoder auth.TokenEncoder) func(chi.Router) { +func MountRoutes() func(chi.Router) { // Initialize handers & controllers. return func(r chi.Router) { - NewSocial(jwtEncoder).MountRoutes(r) + NewSocial().MountRoutes(r) // Provide raw `/auth` handlers - handlers.NewAuth((Auth{}).New(jwtEncoder)).MountRoutes(r) + handlers.NewAuth((Auth{}).New()).MountRoutes(r) - handlers.NewAuthInternal((AuthInternal{}).New(jwtEncoder)).MountRoutes(r) + handlers.NewAuthInternal((AuthInternal{}).New()).MountRoutes(r) // Protect all _private_ routes r.Group(func(r chi.Router) { diff --git a/system/routes.go b/system/routes.go index 8e5df70c4..8ec8fdc05 100644 --- a/system/routes.go +++ b/system/routes.go @@ -12,20 +12,24 @@ import ( "github.com/crusttech/crust/system/rest" ) -func Routes(ctx context.Context, th auth.TokenHandler) *chi.Mux { +func Routes(ctx context.Context) *chi.Mux { r := chi.NewRouter() middleware.Mount(ctx, r, flags.http) - MountRoutes(ctx, r, th) + MountRoutes(ctx, r) routes.Print(r) middleware.MountSystemRoutes(ctx, r, flags.http) return r } -func MountRoutes(ctx context.Context, r chi.Router, th auth.TokenHandler) { +func MountRoutes(ctx context.Context, r chi.Router) { // Only protect application routes with JWT r.Group(func(r chi.Router) { - r.Use(th.Verifier(), th.Authenticator()) - mountRoutes(r, flags.http, rest.MountRoutes(th)) + r.Use( + auth.DefaultJwtHandler.Verifier(), + auth.DefaultJwtHandler.Authenticator(), + ) + + mountRoutes(r, flags.http, rest.MountRoutes()) }) } diff --git a/system/routes_test.go b/system/routes_test.go index fac7ddd91..c17299c3b 100644 --- a/system/routes_test.go +++ b/system/routes_test.go @@ -56,7 +56,7 @@ func TestUsers(t *testing.T) { jwtAuth, err := auth.JWT(jwtSecret, 600) test.NoError(t, err, "Error initializing: %v") - routes := Routes(ctx, jwtAuth) + routes := Routes(ctx) // Send check request with invalid JWT token. { diff --git a/system/start.go b/system/start.go index 38a66e483..75f03ddce 100644 --- a/system/start.go +++ b/system/start.go @@ -21,16 +21,16 @@ import ( "github.com/crusttech/crust/system/service" ) -func Init(ctx context.Context) error { +func Init(ctx context.Context) (err error) { // validate configuration - if err := flags.Validate(); err != nil { - return err + if err = flags.Validate(); err != nil { + return } mail.SetupDialer(flags.smtp) - if err := InitDatabase(ctx); err != nil { - return err + if err = InitDatabase(ctx); err != nil { + return } // configure resputil options @@ -42,9 +42,16 @@ func Init(ctx context.Context) error { }, }) + // Use JWT secret for hmac signer for now + auth.DefaultSigner = auth.HmacSigner(flags.jwt.Secret) + auth.DefaultJwtHandler, err = auth.JWT(flags.jwt.Secret, flags.jwt.Expiry) + if err != nil { + return + } + // Don't change this, it needs database connection - if err := service.Init(); err != nil { - return err + if err = service.Init(); err != nil { + return } return nil @@ -83,12 +90,7 @@ func StartRestAPI(ctx context.Context) error { go metrics.NewMonitor(flags.monitor.Interval) } - jwtAuth, err := auth.JWT(flags.jwt.Secret, flags.jwt.Expiry) - if err != nil { - return errors.Wrap(err, "Error creating JWT Auth") - } - - go http.Serve(listener, Routes(ctx, jwtAuth)) + go http.Serve(listener, Routes(ctx)) <-ctx.Done() return nil