diff --git a/cmd/crm/main.go b/cmd/crm/main.go index 26db9173a..ecd9e0d88 100644 --- a/cmd/crm/main.go +++ b/cmd/crm/main.go @@ -9,7 +9,6 @@ import ( _ "github.com/joho/godotenv/autoload" "github.com/namsral/flag" - "github.com/crusttech/crust/internal/auth" "github.com/crusttech/crust/internal/subscription" "github.com/crusttech/crust/internal/version" @@ -28,7 +27,6 @@ func main() { crm.Flags("crm") system.Flags("system") - auth.Flags() subscription.Flags() flag.Parse() diff --git a/cmd/crust/main.go b/cmd/crust/main.go index 03114c6ce..a5cb5738a 100644 --- a/cmd/crust/main.go +++ b/cmd/crust/main.go @@ -3,11 +3,10 @@ package main import ( "log" "net" + "net/http" "os" "path" - "net/http" - context "github.com/SentimensRG/ctx" "github.com/SentimensRG/ctx/sigctx" "github.com/go-chi/chi" @@ -15,10 +14,10 @@ import ( "github.com/namsral/flag" crm "github.com/crusttech/crust/crm" + "github.com/crusttech/crust/internal/auth" messaging "github.com/crusttech/crust/messaging" system "github.com/crusttech/crust/system" - "github.com/crusttech/crust/internal/auth" "github.com/crusttech/crust/internal/config" "github.com/crusttech/crust/internal/metrics" "github.com/crusttech/crust/internal/middleware" @@ -60,7 +59,8 @@ func main() { messaging.Flags("messaging") system.Flags("system") - auth.Flags() + authJwtFlags := new(config.JWT).Init() + subscription.Flags() flag.Parse() @@ -103,15 +103,20 @@ func main() { // logging, cors and such middleware.Mount(ctx, r, flags.http) + jwtAuth, err := auth.JWT(authJwtFlags.Secret, authJwtFlags.Expiry) + if err != nil { + log.Fatalf("Error creating JWT Auth: %v", err) + } + r.Route("/api", func(r chi.Router) { r.Route("/crm", func(r chi.Router) { - crm.MountRoutes(ctx, r) + crm.MountRoutes(ctx, r, jwtAuth) }) r.Route("/messaging", func(r chi.Router) { - messaging.MountRoutes(ctx, r) + messaging.MountRoutes(ctx, r, jwtAuth) }) r.Route("/system", func(r chi.Router) { - system.MountRoutes(ctx, r) + system.MountRoutes(ctx, r, jwtAuth) }) middleware.MountSystemRoutes(ctx, r, flags.http) }) diff --git a/cmd/messaging/main.go b/cmd/messaging/main.go index 2bb5c2609..8580958f6 100644 --- a/cmd/messaging/main.go +++ b/cmd/messaging/main.go @@ -9,7 +9,6 @@ import ( _ "github.com/joho/godotenv/autoload" "github.com/namsral/flag" - "github.com/crusttech/crust/internal/auth" "github.com/crusttech/crust/internal/subscription" "github.com/crusttech/crust/internal/version" @@ -28,7 +27,6 @@ func main() { messaging.Flags("messaging") system.Flags("system") - auth.Flags() subscription.Flags() flag.Parse() diff --git a/cmd/system/main.go b/cmd/system/main.go index 6d066a0a5..270e666de 100644 --- a/cmd/system/main.go +++ b/cmd/system/main.go @@ -9,7 +9,6 @@ import ( _ "github.com/joho/godotenv/autoload" "github.com/namsral/flag" - "github.com/crusttech/crust/internal/auth" "github.com/crusttech/crust/internal/subscription" "github.com/crusttech/crust/internal/version" @@ -26,7 +25,6 @@ func main() { system.Flags("system") - auth.Flags() subscription.Flags() flag.Parse() diff --git a/crm/flags.go b/crm/flags.go index 934d922aa..1621055eb 100644 --- a/crm/flags.go +++ b/crm/flags.go @@ -14,6 +14,7 @@ type ( monitor *config.Monitor db *config.Database repository *repository.Flags + jwt *config.JWT } ) @@ -55,5 +56,6 @@ func Flags(prefix ...string) { new(config.Monitor).Init(prefix...), new(config.Database).Init(prefix...), new(repository.Flags).Init(prefix...), + new(config.JWT).Init(), } } diff --git a/crm/routes.go b/crm/routes.go index fa0192621..763d2c9f4 100644 --- a/crm/routes.go +++ b/crm/routes.go @@ -6,24 +6,25 @@ import ( "github.com/go-chi/chi" "github.com/crusttech/crust/crm/rest" + "github.com/crusttech/crust/internal/auth" "github.com/crusttech/crust/internal/config" "github.com/crusttech/crust/internal/middleware" "github.com/crusttech/crust/internal/routes" ) -func Routes(ctx context.Context) *chi.Mux { +func Routes(ctx context.Context, th auth.TokenHandler) *chi.Mux { r := chi.NewRouter() middleware.Mount(ctx, r, flags.http) - MountRoutes(ctx, r) + MountRoutes(ctx, r, th) routes.Print(r) middleware.MountSystemRoutes(ctx, r, flags.http) return r } -func MountRoutes(ctx context.Context, r chi.Router) { +func MountRoutes(ctx context.Context, r chi.Router, th auth.TokenHandler) { // Only protect application routes with JWT r.Group(func(r chi.Router) { - r.Use(jwtVerifier, jwtAuthenticator) + r.Use(th.Verifier(), th.Authenticator()) mountRoutes(r, flags.http, rest.MountRoutes()) }) } diff --git a/crm/start.go b/crm/start.go index d9231070c..d9c367667 100644 --- a/crm/start.go +++ b/crm/start.go @@ -18,25 +18,11 @@ import ( "github.com/crusttech/crust/internal/metrics" ) -var ( - jwtVerifier func(http.Handler) http.Handler - jwtAuthenticator func(http.Handler) http.Handler - jwtEncoder auth.TokenEncoder -) - func Init(ctx context.Context) error { // validate configuration if err := flags.Validate(); err != nil { return err } - // JWT Auth - if jwtAuth, err := auth.JWT(); err != nil { - return errors.Wrap(err, "Error creating JWT Auth object") - } else { - jwtEncoder = jwtAuth - jwtVerifier = jwtAuth.Verifier() - jwtAuthenticator = jwtAuth.Authenticator() - } mail.SetupDialer(flags.smtp) @@ -83,7 +69,12 @@ func StartRestAPI(ctx context.Context) error { go metrics.NewMonitor(flags.monitor.Interval) } - go http.Serve(listener, Routes(ctx)) + jwtAuth, err := auth.JWT(flags.jwt.Secret, flags.jwt.Expiry) + if err != nil { + return errors.Wrap(err, "Error creating JWT Auth") + } + + go http.Serve(listener, Routes(ctx, jwtAuth)) <-ctx.Done() return nil diff --git a/internal/auth/flags.go b/internal/auth/flags.go deleted file mode 100644 index 4bcb0a138..000000000 --- a/internal/auth/flags.go +++ /dev/null @@ -1,38 +0,0 @@ -package auth - -import ( - "github.com/crusttech/crust/internal/config" -) - -type ( - localFlags struct { - jwt *config.JWT - } -) - -var flags *localFlags - -// Flags matches signature for main() -func Flags(prefix ...string) { - new(localFlags).Init(prefix...) -} - -func (f *localFlags) Validate() error { - if flags == nil { - return ErrConfigError.New() - } - if err := f.jwt.Validate(); err != nil { - return err - } - return nil -} - -func (f *localFlags) Init(prefix ...string) *localFlags { - if flags != nil { - return flags - } - flags = &localFlags{ - new(config.JWT).Init(prefix...), - } - return flags -} diff --git a/internal/auth/interfaces.go b/internal/auth/interfaces.go index ad10152cc..2a1c758fe 100644 --- a/internal/auth/interfaces.go +++ b/internal/auth/interfaces.go @@ -1,5 +1,9 @@ package auth +import ( + "net/http" +) + type ( Identifiable interface { Identity() uint64 @@ -9,4 +13,10 @@ type ( TokenEncoder interface { Encode(identity Identifiable) string } + + TokenHandler interface { + Encode(identity Identifiable) string + Verifier() func(http.Handler) http.Handler + Authenticator() func(http.Handler) http.Handler + } ) diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index a90b052e6..bc1d75cda 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -1,35 +1,35 @@ package auth import ( - "log" "net/http" "strconv" "time" "github.com/dgrijalva/jwt-go" "github.com/go-chi/jwtauth" + "github.com/pkg/errors" "github.com/titpetric/factory/resputil" ) -type token struct { - expiry int64 - cookieDomain string - tokenAuth *jwtauth.JWTAuth -} - -func JWT() (*token, error) { - if err := flags.Validate(); err != nil { - return nil, err +type ( + token struct { + expiry int64 + tokenAuth *jwtauth.JWTAuth } - jwt := &token{ - expiry: flags.jwt.Expiry, - cookieDomain: flags.jwt.CookieDomain, - tokenAuth: jwtauth.New("HS256", []byte(flags.jwt.Secret), nil), + jwtSettingsGetter interface { + GetGlobalString(name string) (out string, err error) + } +) + +func JWT(secret string, expiry int64) (jwt *token, err error) { + if len(secret) == 0 { + return nil, errors.New("JWT secret missing") } - if flags.jwt.DebugToken { - log.Println("DEBUG JWT TOKEN:", jwt.Encode(NewIdentity(1))) + jwt = &token{ + expiry: expiry, + tokenAuth: jwtauth.New("HS256", []byte(secret), nil), } return jwt, nil diff --git a/internal/config/jwt.go b/internal/config/jwt.go index cb7892498..dfd760232 100644 --- a/internal/config/jwt.go +++ b/internal/config/jwt.go @@ -7,10 +7,8 @@ import ( type ( JWT struct { - Secret string - Expiry int64 - DebugToken bool - CookieDomain string + Secret string + Expiry int64 } ) @@ -34,7 +32,5 @@ func (*JWT) Init(prefix ...string) *JWT { jwt = new(JWT) flag.StringVar(&jwt.Secret, "auth-jwt-secret", "", "JWT Secret") flag.Int64Var(&jwt.Expiry, "auth-jwt-expiry", 3600, "JWT Expiration in minutes") - flag.StringVar(&jwt.CookieDomain, "auth-jwt-cookie-domain", "", "JWT Cookie domain") - flag.BoolVar(&jwt.DebugToken, "auth-jwt-debug", false, "Generate debug JWT key") return jwt } diff --git a/messaging/flags.go b/messaging/flags.go index 2bb5df12d..9e9db6b72 100644 --- a/messaging/flags.go +++ b/messaging/flags.go @@ -14,6 +14,7 @@ type ( monitor *config.Monitor db *config.Database repository *repository.Flags + jwt *config.JWT } ) @@ -55,5 +56,6 @@ func Flags(prefix ...string) { new(config.Monitor).Init(prefix...), new(config.Database).Init(prefix...), new(repository.Flags).Init(prefix...), + new(config.JWT).Init(), } } diff --git a/messaging/routes.go b/messaging/routes.go index f1e7dcccb..e6b6e2176 100644 --- a/messaging/routes.go +++ b/messaging/routes.go @@ -5,6 +5,7 @@ import ( "github.com/go-chi/chi" + "github.com/crusttech/crust/internal/auth" "github.com/crusttech/crust/internal/config" "github.com/crusttech/crust/internal/middleware" "github.com/crusttech/crust/internal/routes" @@ -12,19 +13,19 @@ import ( "github.com/crusttech/crust/messaging/websocket" ) -func Routes(ctx context.Context) *chi.Mux { +func Routes(ctx context.Context, th auth.TokenHandler) *chi.Mux { r := chi.NewRouter() middleware.Mount(ctx, r, flags.http) - MountRoutes(ctx, r) + MountRoutes(ctx, r, th) routes.Print(r) middleware.MountSystemRoutes(ctx, r, flags.http) return r } -func MountRoutes(ctx context.Context, r chi.Router) { +func MountRoutes(ctx context.Context, r chi.Router, th auth.TokenHandler) { // Only protect application routes with JWT r.Group(func(r chi.Router) { - r.Use(jwtVerifier, jwtAuthenticator) + r.Use(th.Verifier(), th.Authenticator()) mountRoutes(r, flags.http, rest.MountRoutes(), websocket.MountRoutes(ctx, flags.repository)) }) } diff --git a/messaging/start.go b/messaging/start.go index 16d684e46..9ca241989 100644 --- a/messaging/start.go +++ b/messaging/start.go @@ -18,25 +18,11 @@ import ( "github.com/crusttech/crust/messaging/internal/service" ) -var ( - jwtVerifier func(http.Handler) http.Handler - jwtAuthenticator func(http.Handler) http.Handler - jwtEncoder auth.TokenEncoder -) - func Init(ctx context.Context) error { // validate configuration if err := flags.Validate(); err != nil { return err } - // JWT Auth - if jwtAuth, err := auth.JWT(); err != nil { - return errors.Wrap(err, "Error creating JWT Auth object") - } else { - jwtEncoder = jwtAuth - jwtVerifier = jwtAuth.Verifier() - jwtAuthenticator = jwtAuth.Authenticator() - } mail.SetupDialer(flags.smtp) @@ -85,7 +71,12 @@ func StartRestAPI(ctx context.Context) error { go metrics.NewMonitor(flags.monitor.Interval) } - go http.Serve(listener, Routes(ctx)) + jwtAuth, err := auth.JWT(flags.jwt.Secret, flags.jwt.Expiry) + if err != nil { + return errors.Wrap(err, "Error creating JWT Auth") + } + + go http.Serve(listener, Routes(ctx, jwtAuth)) <-ctx.Done() return nil diff --git a/system/flags.go b/system/flags.go index 97196f15f..e78fce8af 100644 --- a/system/flags.go +++ b/system/flags.go @@ -12,6 +12,7 @@ type ( http *config.HTTP monitor *config.Monitor db *config.Database + jwt *config.JWT } ) @@ -48,5 +49,6 @@ func Flags(prefix ...string) { new(config.HTTP).Init(prefix...), new(config.Monitor).Init(prefix...), new(config.Database).Init(prefix...), + new(config.JWT).Init(), } } diff --git a/system/routes.go b/system/routes.go index 4f61896d7..8e5df70c4 100644 --- a/system/routes.go +++ b/system/routes.go @@ -5,26 +5,27 @@ import ( "github.com/go-chi/chi" + "github.com/crusttech/crust/internal/auth" "github.com/crusttech/crust/internal/config" "github.com/crusttech/crust/internal/middleware" "github.com/crusttech/crust/internal/routes" "github.com/crusttech/crust/system/rest" ) -func Routes(ctx context.Context) *chi.Mux { +func Routes(ctx context.Context, th auth.TokenHandler) *chi.Mux { r := chi.NewRouter() middleware.Mount(ctx, r, flags.http) - MountRoutes(ctx, r) + MountRoutes(ctx, r, th) routes.Print(r) middleware.MountSystemRoutes(ctx, r, flags.http) return r } -func MountRoutes(ctx context.Context, r chi.Router) { +func MountRoutes(ctx context.Context, r chi.Router, th auth.TokenHandler) { // Only protect application routes with JWT r.Group(func(r chi.Router) { - r.Use(jwtVerifier, jwtAuthenticator) - mountRoutes(r, flags.http, rest.MountRoutes(jwtEncoder)) + r.Use(th.Verifier(), th.Authenticator()) + mountRoutes(r, flags.http, rest.MountRoutes(th)) }) } diff --git a/system/start.go b/system/start.go index 3df0b17fc..38a66e483 100644 --- a/system/start.go +++ b/system/start.go @@ -21,25 +21,11 @@ import ( "github.com/crusttech/crust/system/service" ) -var ( - jwtVerifier func(http.Handler) http.Handler - jwtAuthenticator func(http.Handler) http.Handler - jwtEncoder auth.TokenEncoder -) - func Init(ctx context.Context) error { // validate configuration if err := flags.Validate(); err != nil { return err } - // JWT Auth - if jwtAuth, err := auth.JWT(); err != nil { - return errors.Wrap(err, "Error creating JWT Auth object") - } else { - jwtEncoder = jwtAuth - jwtVerifier = jwtAuth.Verifier() - jwtAuthenticator = jwtAuth.Authenticator() - } mail.SetupDialer(flags.smtp) @@ -82,10 +68,10 @@ func InitDatabase(ctx context.Context) error { func StartRestAPI(ctx context.Context) error { // Load settings from the database, // for now, only at start-up time. - settingService := settings.NewService(settings.NewRepository(repository.DB(ctx), "sys_settings")) + settingService := settings.NewService(settings.NewRepository(repository.DB(ctx), "sys_settings")).With(ctx) - // Setup goth/social authentication - external.Init(settingService.With(ctx)) + // Setup goth/external authentication + external.Init(settingService) log.Println("Starting http server on address " + flags.http.Addr) listener, err := net.Listen("tcp", flags.http.Addr) @@ -97,7 +83,12 @@ func StartRestAPI(ctx context.Context) error { go metrics.NewMonitor(flags.monitor.Interval) } - go http.Serve(listener, Routes(ctx)) + jwtAuth, err := auth.JWT(flags.jwt.Secret, flags.jwt.Expiry) + if err != nil { + return errors.Wrap(err, "Error creating JWT Auth") + } + + go http.Serve(listener, Routes(ctx, jwtAuth)) <-ctx.Done() return nil