Refactor JWT init flow
This commit is contained in:
parent
c4104488e5
commit
11def550c6
@ -9,7 +9,6 @@ import (
|
||||
_ "github.com/joho/godotenv/autoload"
|
||||
"github.com/namsral/flag"
|
||||
|
||||
"github.com/crusttech/crust/internal/auth"
|
||||
"github.com/crusttech/crust/internal/subscription"
|
||||
"github.com/crusttech/crust/internal/version"
|
||||
|
||||
@ -28,7 +27,6 @@ func main() {
|
||||
crm.Flags("crm")
|
||||
system.Flags("system")
|
||||
|
||||
auth.Flags()
|
||||
subscription.Flags()
|
||||
|
||||
flag.Parse()
|
||||
|
||||
@ -3,11 +3,10 @@ package main
|
||||
import (
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"net/http"
|
||||
|
||||
context "github.com/SentimensRG/ctx"
|
||||
"github.com/SentimensRG/ctx/sigctx"
|
||||
"github.com/go-chi/chi"
|
||||
@ -15,10 +14,10 @@ import (
|
||||
"github.com/namsral/flag"
|
||||
|
||||
crm "github.com/crusttech/crust/crm"
|
||||
"github.com/crusttech/crust/internal/auth"
|
||||
messaging "github.com/crusttech/crust/messaging"
|
||||
system "github.com/crusttech/crust/system"
|
||||
|
||||
"github.com/crusttech/crust/internal/auth"
|
||||
"github.com/crusttech/crust/internal/config"
|
||||
"github.com/crusttech/crust/internal/metrics"
|
||||
"github.com/crusttech/crust/internal/middleware"
|
||||
@ -60,7 +59,8 @@ func main() {
|
||||
messaging.Flags("messaging")
|
||||
system.Flags("system")
|
||||
|
||||
auth.Flags()
|
||||
authJwtFlags := new(config.JWT).Init()
|
||||
|
||||
subscription.Flags()
|
||||
|
||||
flag.Parse()
|
||||
@ -103,15 +103,20 @@ func main() {
|
||||
// logging, cors and such
|
||||
middleware.Mount(ctx, r, flags.http)
|
||||
|
||||
jwtAuth, err := auth.JWT(authJwtFlags.Secret, authJwtFlags.Expiry)
|
||||
if err != nil {
|
||||
log.Fatalf("Error creating JWT Auth: %v", err)
|
||||
}
|
||||
|
||||
r.Route("/api", func(r chi.Router) {
|
||||
r.Route("/crm", func(r chi.Router) {
|
||||
crm.MountRoutes(ctx, r)
|
||||
crm.MountRoutes(ctx, r, jwtAuth)
|
||||
})
|
||||
r.Route("/messaging", func(r chi.Router) {
|
||||
messaging.MountRoutes(ctx, r)
|
||||
messaging.MountRoutes(ctx, r, jwtAuth)
|
||||
})
|
||||
r.Route("/system", func(r chi.Router) {
|
||||
system.MountRoutes(ctx, r)
|
||||
system.MountRoutes(ctx, r, jwtAuth)
|
||||
})
|
||||
middleware.MountSystemRoutes(ctx, r, flags.http)
|
||||
})
|
||||
|
||||
@ -9,7 +9,6 @@ import (
|
||||
_ "github.com/joho/godotenv/autoload"
|
||||
"github.com/namsral/flag"
|
||||
|
||||
"github.com/crusttech/crust/internal/auth"
|
||||
"github.com/crusttech/crust/internal/subscription"
|
||||
"github.com/crusttech/crust/internal/version"
|
||||
|
||||
@ -28,7 +27,6 @@ func main() {
|
||||
messaging.Flags("messaging")
|
||||
system.Flags("system")
|
||||
|
||||
auth.Flags()
|
||||
subscription.Flags()
|
||||
|
||||
flag.Parse()
|
||||
|
||||
@ -9,7 +9,6 @@ import (
|
||||
_ "github.com/joho/godotenv/autoload"
|
||||
"github.com/namsral/flag"
|
||||
|
||||
"github.com/crusttech/crust/internal/auth"
|
||||
"github.com/crusttech/crust/internal/subscription"
|
||||
"github.com/crusttech/crust/internal/version"
|
||||
|
||||
@ -26,7 +25,6 @@ func main() {
|
||||
|
||||
system.Flags("system")
|
||||
|
||||
auth.Flags()
|
||||
subscription.Flags()
|
||||
|
||||
flag.Parse()
|
||||
|
||||
@ -14,6 +14,7 @@ type (
|
||||
monitor *config.Monitor
|
||||
db *config.Database
|
||||
repository *repository.Flags
|
||||
jwt *config.JWT
|
||||
}
|
||||
)
|
||||
|
||||
@ -55,5 +56,6 @@ func Flags(prefix ...string) {
|
||||
new(config.Monitor).Init(prefix...),
|
||||
new(config.Database).Init(prefix...),
|
||||
new(repository.Flags).Init(prefix...),
|
||||
new(config.JWT).Init(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,24 +6,25 @@ import (
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"github.com/crusttech/crust/crm/rest"
|
||||
"github.com/crusttech/crust/internal/auth"
|
||||
"github.com/crusttech/crust/internal/config"
|
||||
"github.com/crusttech/crust/internal/middleware"
|
||||
"github.com/crusttech/crust/internal/routes"
|
||||
)
|
||||
|
||||
func Routes(ctx context.Context) *chi.Mux {
|
||||
func Routes(ctx context.Context, th auth.TokenHandler) *chi.Mux {
|
||||
r := chi.NewRouter()
|
||||
middleware.Mount(ctx, r, flags.http)
|
||||
MountRoutes(ctx, r)
|
||||
MountRoutes(ctx, r, th)
|
||||
routes.Print(r)
|
||||
middleware.MountSystemRoutes(ctx, r, flags.http)
|
||||
return r
|
||||
}
|
||||
|
||||
func MountRoutes(ctx context.Context, r chi.Router) {
|
||||
func MountRoutes(ctx context.Context, r chi.Router, th auth.TokenHandler) {
|
||||
// Only protect application routes with JWT
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(jwtVerifier, jwtAuthenticator)
|
||||
r.Use(th.Verifier(), th.Authenticator())
|
||||
mountRoutes(r, flags.http, rest.MountRoutes())
|
||||
})
|
||||
}
|
||||
|
||||
21
crm/start.go
21
crm/start.go
@ -18,25 +18,11 @@ import (
|
||||
"github.com/crusttech/crust/internal/metrics"
|
||||
)
|
||||
|
||||
var (
|
||||
jwtVerifier func(http.Handler) http.Handler
|
||||
jwtAuthenticator func(http.Handler) http.Handler
|
||||
jwtEncoder auth.TokenEncoder
|
||||
)
|
||||
|
||||
func Init(ctx context.Context) error {
|
||||
// validate configuration
|
||||
if err := flags.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
// JWT Auth
|
||||
if jwtAuth, err := auth.JWT(); err != nil {
|
||||
return errors.Wrap(err, "Error creating JWT Auth object")
|
||||
} else {
|
||||
jwtEncoder = jwtAuth
|
||||
jwtVerifier = jwtAuth.Verifier()
|
||||
jwtAuthenticator = jwtAuth.Authenticator()
|
||||
}
|
||||
|
||||
mail.SetupDialer(flags.smtp)
|
||||
|
||||
@ -83,7 +69,12 @@ func StartRestAPI(ctx context.Context) error {
|
||||
go metrics.NewMonitor(flags.monitor.Interval)
|
||||
}
|
||||
|
||||
go http.Serve(listener, Routes(ctx))
|
||||
jwtAuth, err := auth.JWT(flags.jwt.Secret, flags.jwt.Expiry)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Error creating JWT Auth")
|
||||
}
|
||||
|
||||
go http.Serve(listener, Routes(ctx, jwtAuth))
|
||||
<-ctx.Done()
|
||||
|
||||
return nil
|
||||
|
||||
@ -1,38 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"github.com/crusttech/crust/internal/config"
|
||||
)
|
||||
|
||||
type (
|
||||
localFlags struct {
|
||||
jwt *config.JWT
|
||||
}
|
||||
)
|
||||
|
||||
var flags *localFlags
|
||||
|
||||
// Flags matches signature for main()
|
||||
func Flags(prefix ...string) {
|
||||
new(localFlags).Init(prefix...)
|
||||
}
|
||||
|
||||
func (f *localFlags) Validate() error {
|
||||
if flags == nil {
|
||||
return ErrConfigError.New()
|
||||
}
|
||||
if err := f.jwt.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *localFlags) Init(prefix ...string) *localFlags {
|
||||
if flags != nil {
|
||||
return flags
|
||||
}
|
||||
flags = &localFlags{
|
||||
new(config.JWT).Init(prefix...),
|
||||
}
|
||||
return flags
|
||||
}
|
||||
@ -1,5 +1,9 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type (
|
||||
Identifiable interface {
|
||||
Identity() uint64
|
||||
@ -9,4 +13,10 @@ type (
|
||||
TokenEncoder interface {
|
||||
Encode(identity Identifiable) string
|
||||
}
|
||||
|
||||
TokenHandler interface {
|
||||
Encode(identity Identifiable) string
|
||||
Verifier() func(http.Handler) http.Handler
|
||||
Authenticator() func(http.Handler) http.Handler
|
||||
}
|
||||
)
|
||||
|
||||
@ -1,35 +1,35 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/go-chi/jwtauth"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/titpetric/factory/resputil"
|
||||
)
|
||||
|
||||
type token struct {
|
||||
expiry int64
|
||||
cookieDomain string
|
||||
tokenAuth *jwtauth.JWTAuth
|
||||
}
|
||||
|
||||
func JWT() (*token, error) {
|
||||
if err := flags.Validate(); err != nil {
|
||||
return nil, err
|
||||
type (
|
||||
token struct {
|
||||
expiry int64
|
||||
tokenAuth *jwtauth.JWTAuth
|
||||
}
|
||||
|
||||
jwt := &token{
|
||||
expiry: flags.jwt.Expiry,
|
||||
cookieDomain: flags.jwt.CookieDomain,
|
||||
tokenAuth: jwtauth.New("HS256", []byte(flags.jwt.Secret), nil),
|
||||
jwtSettingsGetter interface {
|
||||
GetGlobalString(name string) (out string, err error)
|
||||
}
|
||||
)
|
||||
|
||||
func JWT(secret string, expiry int64) (jwt *token, err error) {
|
||||
if len(secret) == 0 {
|
||||
return nil, errors.New("JWT secret missing")
|
||||
}
|
||||
|
||||
if flags.jwt.DebugToken {
|
||||
log.Println("DEBUG JWT TOKEN:", jwt.Encode(NewIdentity(1)))
|
||||
jwt = &token{
|
||||
expiry: expiry,
|
||||
tokenAuth: jwtauth.New("HS256", []byte(secret), nil),
|
||||
}
|
||||
|
||||
return jwt, nil
|
||||
|
||||
@ -7,10 +7,8 @@ import (
|
||||
|
||||
type (
|
||||
JWT struct {
|
||||
Secret string
|
||||
Expiry int64
|
||||
DebugToken bool
|
||||
CookieDomain string
|
||||
Secret string
|
||||
Expiry int64
|
||||
}
|
||||
)
|
||||
|
||||
@ -34,7 +32,5 @@ func (*JWT) Init(prefix ...string) *JWT {
|
||||
jwt = new(JWT)
|
||||
flag.StringVar(&jwt.Secret, "auth-jwt-secret", "", "JWT Secret")
|
||||
flag.Int64Var(&jwt.Expiry, "auth-jwt-expiry", 3600, "JWT Expiration in minutes")
|
||||
flag.StringVar(&jwt.CookieDomain, "auth-jwt-cookie-domain", "", "JWT Cookie domain")
|
||||
flag.BoolVar(&jwt.DebugToken, "auth-jwt-debug", false, "Generate debug JWT key")
|
||||
return jwt
|
||||
}
|
||||
|
||||
@ -14,6 +14,7 @@ type (
|
||||
monitor *config.Monitor
|
||||
db *config.Database
|
||||
repository *repository.Flags
|
||||
jwt *config.JWT
|
||||
}
|
||||
)
|
||||
|
||||
@ -55,5 +56,6 @@ func Flags(prefix ...string) {
|
||||
new(config.Monitor).Init(prefix...),
|
||||
new(config.Database).Init(prefix...),
|
||||
new(repository.Flags).Init(prefix...),
|
||||
new(config.JWT).Init(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"github.com/crusttech/crust/internal/auth"
|
||||
"github.com/crusttech/crust/internal/config"
|
||||
"github.com/crusttech/crust/internal/middleware"
|
||||
"github.com/crusttech/crust/internal/routes"
|
||||
@ -12,19 +13,19 @@ import (
|
||||
"github.com/crusttech/crust/messaging/websocket"
|
||||
)
|
||||
|
||||
func Routes(ctx context.Context) *chi.Mux {
|
||||
func Routes(ctx context.Context, th auth.TokenHandler) *chi.Mux {
|
||||
r := chi.NewRouter()
|
||||
middleware.Mount(ctx, r, flags.http)
|
||||
MountRoutes(ctx, r)
|
||||
MountRoutes(ctx, r, th)
|
||||
routes.Print(r)
|
||||
middleware.MountSystemRoutes(ctx, r, flags.http)
|
||||
return r
|
||||
}
|
||||
|
||||
func MountRoutes(ctx context.Context, r chi.Router) {
|
||||
func MountRoutes(ctx context.Context, r chi.Router, th auth.TokenHandler) {
|
||||
// Only protect application routes with JWT
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(jwtVerifier, jwtAuthenticator)
|
||||
r.Use(th.Verifier(), th.Authenticator())
|
||||
mountRoutes(r, flags.http, rest.MountRoutes(), websocket.MountRoutes(ctx, flags.repository))
|
||||
})
|
||||
}
|
||||
|
||||
@ -18,25 +18,11 @@ import (
|
||||
"github.com/crusttech/crust/messaging/internal/service"
|
||||
)
|
||||
|
||||
var (
|
||||
jwtVerifier func(http.Handler) http.Handler
|
||||
jwtAuthenticator func(http.Handler) http.Handler
|
||||
jwtEncoder auth.TokenEncoder
|
||||
)
|
||||
|
||||
func Init(ctx context.Context) error {
|
||||
// validate configuration
|
||||
if err := flags.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
// JWT Auth
|
||||
if jwtAuth, err := auth.JWT(); err != nil {
|
||||
return errors.Wrap(err, "Error creating JWT Auth object")
|
||||
} else {
|
||||
jwtEncoder = jwtAuth
|
||||
jwtVerifier = jwtAuth.Verifier()
|
||||
jwtAuthenticator = jwtAuth.Authenticator()
|
||||
}
|
||||
|
||||
mail.SetupDialer(flags.smtp)
|
||||
|
||||
@ -85,7 +71,12 @@ func StartRestAPI(ctx context.Context) error {
|
||||
go metrics.NewMonitor(flags.monitor.Interval)
|
||||
}
|
||||
|
||||
go http.Serve(listener, Routes(ctx))
|
||||
jwtAuth, err := auth.JWT(flags.jwt.Secret, flags.jwt.Expiry)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Error creating JWT Auth")
|
||||
}
|
||||
|
||||
go http.Serve(listener, Routes(ctx, jwtAuth))
|
||||
<-ctx.Done()
|
||||
|
||||
return nil
|
||||
|
||||
@ -12,6 +12,7 @@ type (
|
||||
http *config.HTTP
|
||||
monitor *config.Monitor
|
||||
db *config.Database
|
||||
jwt *config.JWT
|
||||
}
|
||||
)
|
||||
|
||||
@ -48,5 +49,6 @@ func Flags(prefix ...string) {
|
||||
new(config.HTTP).Init(prefix...),
|
||||
new(config.Monitor).Init(prefix...),
|
||||
new(config.Database).Init(prefix...),
|
||||
new(config.JWT).Init(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -5,26 +5,27 @@ import (
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"github.com/crusttech/crust/internal/auth"
|
||||
"github.com/crusttech/crust/internal/config"
|
||||
"github.com/crusttech/crust/internal/middleware"
|
||||
"github.com/crusttech/crust/internal/routes"
|
||||
"github.com/crusttech/crust/system/rest"
|
||||
)
|
||||
|
||||
func Routes(ctx context.Context) *chi.Mux {
|
||||
func Routes(ctx context.Context, th auth.TokenHandler) *chi.Mux {
|
||||
r := chi.NewRouter()
|
||||
middleware.Mount(ctx, r, flags.http)
|
||||
MountRoutes(ctx, r)
|
||||
MountRoutes(ctx, r, th)
|
||||
routes.Print(r)
|
||||
middleware.MountSystemRoutes(ctx, r, flags.http)
|
||||
return r
|
||||
}
|
||||
|
||||
func MountRoutes(ctx context.Context, r chi.Router) {
|
||||
func MountRoutes(ctx context.Context, r chi.Router, th auth.TokenHandler) {
|
||||
// Only protect application routes with JWT
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(jwtVerifier, jwtAuthenticator)
|
||||
mountRoutes(r, flags.http, rest.MountRoutes(jwtEncoder))
|
||||
r.Use(th.Verifier(), th.Authenticator())
|
||||
mountRoutes(r, flags.http, rest.MountRoutes(th))
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -21,25 +21,11 @@ import (
|
||||
"github.com/crusttech/crust/system/service"
|
||||
)
|
||||
|
||||
var (
|
||||
jwtVerifier func(http.Handler) http.Handler
|
||||
jwtAuthenticator func(http.Handler) http.Handler
|
||||
jwtEncoder auth.TokenEncoder
|
||||
)
|
||||
|
||||
func Init(ctx context.Context) error {
|
||||
// validate configuration
|
||||
if err := flags.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
// JWT Auth
|
||||
if jwtAuth, err := auth.JWT(); err != nil {
|
||||
return errors.Wrap(err, "Error creating JWT Auth object")
|
||||
} else {
|
||||
jwtEncoder = jwtAuth
|
||||
jwtVerifier = jwtAuth.Verifier()
|
||||
jwtAuthenticator = jwtAuth.Authenticator()
|
||||
}
|
||||
|
||||
mail.SetupDialer(flags.smtp)
|
||||
|
||||
@ -82,10 +68,10 @@ func InitDatabase(ctx context.Context) error {
|
||||
func StartRestAPI(ctx context.Context) error {
|
||||
// Load settings from the database,
|
||||
// for now, only at start-up time.
|
||||
settingService := settings.NewService(settings.NewRepository(repository.DB(ctx), "sys_settings"))
|
||||
settingService := settings.NewService(settings.NewRepository(repository.DB(ctx), "sys_settings")).With(ctx)
|
||||
|
||||
// Setup goth/social authentication
|
||||
external.Init(settingService.With(ctx))
|
||||
// Setup goth/external authentication
|
||||
external.Init(settingService)
|
||||
|
||||
log.Println("Starting http server on address " + flags.http.Addr)
|
||||
listener, err := net.Listen("tcp", flags.http.Addr)
|
||||
@ -97,7 +83,12 @@ func StartRestAPI(ctx context.Context) error {
|
||||
go metrics.NewMonitor(flags.monitor.Interval)
|
||||
}
|
||||
|
||||
go http.Serve(listener, Routes(ctx))
|
||||
jwtAuth, err := auth.JWT(flags.jwt.Secret, flags.jwt.Expiry)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Error creating JWT Auth")
|
||||
}
|
||||
|
||||
go http.Serve(listener, Routes(ctx, jwtAuth))
|
||||
<-ctx.Done()
|
||||
|
||||
return nil
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user