diff --git a/auth/docs/src/spec.json b/auth/docs/src/spec.json new file mode 100644 index 000000000..79cf67c72 --- /dev/null +++ b/auth/docs/src/spec.json @@ -0,0 +1,36 @@ +[ + { + "title": "Authentication", + "package": "auth", + "path": "/auth", + "entrypoint": "auth", + "authentication": [], + "apis": [ + { + "name": "login", + "method": "POST", + "title": "User login", + "parameters": { + "post": [ + { "type": "string", "name": "username", "required": true, "title": "Username or email" }, + { "type": "string", "name": "password", "required": true, "title": "Password for user" } + ] + } + }, + { + "name": "create", + "path": "/create", + "method": "POST", + "title": "Create new user", + "parameters": { + "post": [ + { "type": "string", "name": "name", "required": true, "title": "Display name" }, + { "type": "string", "name": "email", "required": true, "title": "Email" }, + { "type": "string", "name": "username", "required": true, "title": "Username" }, + { "type": "string", "name": "password", "required": true, "title": "Password" } + ] + } + } + ] + } +] diff --git a/auth/docs/src/spec/auth.json b/auth/docs/src/spec/auth.json new file mode 100644 index 000000000..4bc01e93e --- /dev/null +++ b/auth/docs/src/spec/auth.json @@ -0,0 +1,68 @@ +{ + "Title": "Authentication", + "Package": "auth", + "Interface": "Auth", + "Struct": null, + "Parameters": null, + "Protocol": "", + "Authentication": [], + "Path": "/auth", + "APIs": [ + { + "Name": "login", + "Method": "POST", + "Title": "User login", + "Path": "/login", + "Parameters": { + "post": [ + { + "name": "username", + "required": true, + "title": "Username or email", + "type": "string" + }, + { + "name": "password", + "required": true, + "title": "Password for user", + "type": "string" + } + ] + } + }, + { + "Name": "create", + "Method": "POST", + "Title": "Create new user", + "Path": "/create", + "Parameters": { + "post": [ + { + "name": "name", + "required": true, + "title": "Display name", + "type": "string" + }, + { + "name": "email", + "required": true, + "title": "Email", + "type": "string" + }, + { + "name": "username", + "required": true, + "title": "Username", + "type": "string" + }, + { + "name": "password", + "required": true, + "title": "Password", + "type": "string" + } + ] + } + } + ] +} \ No newline at end of file diff --git a/auth/flags.go b/auth/flags.go index d5f05e91c..6edc5f006 100644 --- a/auth/flags.go +++ b/auth/flags.go @@ -1,31 +1,52 @@ package auth import ( - "github.com/namsral/flag" "github.com/pkg/errors" ) type ( configuration struct { - jwtSecret string - jwtExpiry int64 - jwtDebug bool + http *httpFlags + db *dbFlags + jwt *jwtFlags } ) -var config configuration +var config *configuration -func (c configuration) validate() error { - if c.jwtSecret == "" { - return errors.New("JWT Secret not set for AUTH") +func (configuration) New() *configuration { + return &configuration{ + new(httpFlags), + new(dbFlags), + new(jwtFlags), } +} +func (c *configuration) validate() error { + if c == nil { + return errors.New("CRM config is not initialized, need to call Flags()") + } + if err := c.http.validate(); err != nil { + return err + } + if err := c.db.validate(); err != nil { + return err + } + if err := c.jwt.validate(); err != nil { + return err + } return nil } -// Flags should be called from main to register flags -func Flags(_ ...string) { - flag.StringVar(&config.jwtSecret, "auth-jwt-secret", "", "JWT Secret") - flag.Int64Var(&config.jwtExpiry, "auth-jwt-expiry", 3600, "JWT Expiration in minutes") - flag.BoolVar(&config.jwtDebug, "auth-jwt-debug", false, "Generate debug JWT key") +func Flags(prefix ...string) { + if config != nil { + return + } + if len(prefix) == 0 { + panic("crm.Flags() needs prefix on first call") + } + config = configuration{}.New() + config.http.flags(prefix...) + config.db.flags(prefix...) + config.jwt.flags(prefix...) } diff --git a/auth/flags_db.go b/auth/flags_db.go new file mode 100644 index 000000000..9b72ddfea --- /dev/null +++ b/auth/flags_db.go @@ -0,0 +1,29 @@ +package auth + +import ( + "github.com/namsral/flag" + "github.com/pkg/errors" +) + +type ( + dbFlags struct { + dsn string + profiler string + } +) + +func (c *dbFlags) validate() error { + if c.dsn == "" { + return errors.New("No DB DSN is set, can't connect to database") + } + return nil +} + +func (c *dbFlags) flags(prefix ...string) { + p := func(s string) string { + return prefix[0] + "-" + s + } + + flag.StringVar(&c.dsn, p("db-dsn"), "crust:crust@tcp(db1:3306)/crust?collation=utf8mb4_general_ci", "DSN for database connection") + flag.StringVar(&c.profiler, p("db-profiler"), "", "Profiler for DB queries (none, stdout)") +} diff --git a/auth/flags_http.go b/auth/flags_http.go new file mode 100644 index 000000000..f4b78ed1e --- /dev/null +++ b/auth/flags_http.go @@ -0,0 +1,35 @@ +package auth + +import ( + "github.com/namsral/flag" + "github.com/pkg/errors" +) + +type ( + httpFlags struct { + addr string + logging bool + pretty bool + tracing bool + metrics bool + } +) + +func (c *httpFlags) validate() error { + if c.addr == "" { + return errors.New("No HTTP Addr is set, can't listen for HTTP") + } + return nil +} + +func (c *httpFlags) flags(prefix ...string) { + p := func(s string) string { + return prefix[0] + "-" + s + } + + flag.StringVar(&c.addr, p("http-addr"), ":3000", "Listen address for HTTP server") + flag.BoolVar(&c.logging, p("http-log"), true, "Enable/disable HTTP request log") + flag.BoolVar(&c.pretty, p("http-pretty-json"), false, "Prettify returned JSON output") + flag.BoolVar(&c.tracing, p("http-error-tracing"), false, "Return error stack frame") + flag.BoolVar(&c.metrics, p("http-metrics"), false, "Provide metrics export for prometheus") +} diff --git a/auth/flags_jwt.go b/auth/flags_jwt.go new file mode 100644 index 000000000..96ee96cd2 --- /dev/null +++ b/auth/flags_jwt.go @@ -0,0 +1,27 @@ +package auth + +import ( + "github.com/namsral/flag" + "github.com/pkg/errors" +) + +type ( + jwtFlags struct { + secret string + expiry int64 + debugToken bool + } +) + +func (c *jwtFlags) validate() error { + if c.secret == "" { + return errors.New("JWT Secret not set for AUTH") + } + return nil +} + +func (c *jwtFlags) flags(prefix ...string) { + flag.StringVar(&c.secret, "auth-jwt-secret", "", "JWT Secret") + flag.Int64Var(&c.expiry, "auth-jwt-expiry", 3600, "JWT Expiration in minutes") + flag.BoolVar(&c.debugToken, "auth-jwt-debug", false, "Generate debug JWT key") +} diff --git a/auth/auth.go b/auth/identity.go similarity index 100% rename from auth/auth.go rename to auth/identity.go diff --git a/auth/context.go b/auth/identity_context.go similarity index 79% rename from auth/context.go rename to auth/identity_context.go index 7c065572a..942ceb2b8 100644 --- a/auth/context.go +++ b/auth/identity_context.go @@ -2,9 +2,11 @@ package auth import ( "context" + "strconv" + + "github.com/crusttech/crust/auth/types" "github.com/go-chi/jwtauth" "github.com/pkg/errors" - "strconv" ) type ( @@ -34,12 +36,12 @@ func getIdentityClaimFromContext(ctx context.Context) (uint64, error) { } } -func SetIdentityToContext(ctx context.Context, identity Identifiable) context.Context { +func SetIdentityToContext(ctx context.Context, identity types.Identifiable) context.Context { return context.WithValue(ctx, identityCtxKey, identity) } -func GetIdentityFromContext(ctx context.Context) Identifiable { - if identity, ok := ctx.Value(identityCtxKey).(Identifiable); ok { +func GetIdentityFromContext(ctx context.Context) types.Identifiable { + if identity, ok := ctx.Value(identityCtxKey).(types.Identifiable); ok { return identity } else { return NewIdentity(0) diff --git a/auth/jwt.go b/auth/jwt.go index b9e631360..9ea2d44af 100644 --- a/auth/jwt.go +++ b/auth/jwt.go @@ -1,12 +1,14 @@ package auth import ( - "github.com/go-chi/jwtauth" - "github.com/titpetric/factory/resputil" "log" "net/http" "strconv" "time" + + "github.com/crusttech/crust/auth/types" + "github.com/go-chi/jwtauth" + "github.com/titpetric/factory/resputil" ) type jwt struct { @@ -18,9 +20,9 @@ func JWT() (*jwt, error) { return nil, err } - jwt := &jwt{tokenAuth: jwtauth.New("HS256", []byte(config.jwtSecret), nil)} + jwt := &jwt{tokenAuth: jwtauth.New("HS256", []byte(config.jwt.secret), nil)} - if config.jwtDebug { + if config.jwt.debugToken { log.Println("DEBUG JWT TOKEN:", jwt.Encode(NewIdentity(1))) } @@ -32,11 +34,11 @@ func (t *jwt) Verifier() func(http.Handler) http.Handler { return jwtauth.Verifier(t.tokenAuth) } -func (t *jwt) Encode(identity Identifiable) string { +func (t *jwt) Encode(identity types.Identifiable) string { // @todo Set expiry claims := jwtauth.Claims{} claims.Set("sub", strconv.FormatUint(identity.Identity(), 10)) - claims.SetExpiryIn(time.Duration(config.jwtExpiry) * time.Minute) + claims.SetExpiryIn(time.Duration(config.jwt.expiry) * time.Minute) _, jwt, _ := t.tokenAuth.Encode(claims) return jwt diff --git a/auth/metrics.go b/auth/metrics.go new file mode 100644 index 000000000..47a4c2332 --- /dev/null +++ b/auth/metrics.go @@ -0,0 +1,20 @@ +package auth + +import ( + "net/http" + + "github.com/766b/chi-prometheus" + "github.com/prometheus/client_golang/prometheus" +) + +type metrics struct{} + +// Middleware is the request logger that provides metrics to prometheus +func (metrics) Middleware(name string) func(http.Handler) http.Handler { + return chiprometheus.NewMiddleware(name) +} + +// Handler exports prometheus metrics for /metrics requests +func (metrics) Handler() http.Handler { + return prometheus.Handler() +} diff --git a/auth/http.go b/auth/middleware.go similarity index 100% rename from auth/http.go rename to auth/middleware.go diff --git a/auth/repository/db.go b/auth/repository/db.go new file mode 100644 index 000000000..828d80514 --- /dev/null +++ b/auth/repository/db.go @@ -0,0 +1,19 @@ +package repository + +import ( + "context" + "github.com/titpetric/factory" +) + +var _db *factory.DB + +func DB(ctxs ...context.Context) *factory.DB { + if _db == nil { + _db = factory.Database.MustGet() + } + for _, ctx := range ctxs { + _db = _db.With(ctx) + break + } + return _db +} diff --git a/auth/repository/error.go b/auth/repository/error.go new file mode 100644 index 000000000..309b60aee --- /dev/null +++ b/auth/repository/error.go @@ -0,0 +1,18 @@ +package repository + +type ( + repositoryError string +) + +const ( + ErrDatabaseError = repositoryError("DatabaseError") + ErrNotImplemented = repositoryError("NotImplemented") +) + +func (e repositoryError) Error() string { + return e.String() +} + +func (e repositoryError) String() string { + return "crust.auth.repository." + string(e) +} diff --git a/auth/repository/repository.go b/auth/repository/repository.go new file mode 100644 index 000000000..505c98925 --- /dev/null +++ b/auth/repository/repository.go @@ -0,0 +1,44 @@ +package repository + +import ( + "context" + "github.com/titpetric/factory" +) + +type ( + repository struct { + ctx context.Context + + // Get database handle + dbh func(ctxs ...context.Context) *factory.DB + } + + Repository interface { + Context() context.Context + DB() *factory.DB + } +) + +// With updates repository and database contexts +func (r *repository) With(ctx context.Context) *repository { + res := &repository{ + ctx: ctx, + dbh: DB, + } + if r != nil { + res.dbh = r.dbh + } + return res +} + +func (r *repository) Context() context.Context { + return r.ctx +} + +// Return context-aware db handle +func (r *repository) db() *factory.DB { + return r.dbh(r.ctx) +} +func (r *repository) DB() *factory.DB { + return r.db() +} diff --git a/auth/repository/user.go b/auth/repository/user.go new file mode 100644 index 000000000..8b4b0e51d --- /dev/null +++ b/auth/repository/user.go @@ -0,0 +1,108 @@ +package repository + +import ( + "context" + "github.com/crusttech/crust/auth/types" + "github.com/titpetric/factory" + "time" +) + +type ( + user struct { + *repository + } + + User interface { + Repository + + With(context.Context) User + + FindUserByUsername(username string) (*types.User, error) + FindUserByID(id uint64) (*types.User, error) + FindUsers(filter *types.UserFilter) ([]*types.User, error) + CreateUser(mod *types.User) (*types.User, error) + UpdateUser(mod *types.User) (*types.User, error) + SuspendUserByID(id uint64) error + UnsuspendUserByID(id uint64) error + DeleteUserByID(id uint64) error + } +) + +const ( + sqlUserScope = "suspended_at IS NULL AND deleted_at IS NULL" + sqlUserSelect = "SELECT * FROM users WHERE " + sqlUserScope + + ErrUserNotFound = repositoryError("UserNotFound") +) + +func NewUser(ctx context.Context) User { + return (&user{}).With(ctx) +} + +func (r *user) With(ctx context.Context) User { + return &user{ + repository: r.repository.With(ctx), + } +} + +func (r *user) FindUserByUsername(username string) (*types.User, error) { + sql := "SELECT * FROM users WHERE username = ? AND " + sqlUserScope + mod := &types.User{} + + return mod, isFound(r.db().Get(mod, sql, username), mod.ID > 0, ErrUserNotFound) +} + +func (r *user) FindUserByID(id uint64) (*types.User, error) { + sql := "SELECT * FROM users WHERE id = ? AND " + sqlUserScope + mod := &types.User{} + + return mod, isFound(r.db().Get(mod, sql, id), mod.ID > 0, ErrUserNotFound) +} + +func (r *user) FindUsers(filter *types.UserFilter) ([]*types.User, error) { + rval := make([]*types.User, 0) + params := make([]interface{}, 0) + sql := "SELECT * FROM users WHERE " + sqlUserScope + + if filter != nil { + if filter.Query != "" { + sql += " AND username LIKE ?" + params = append(params, filter.Query+"%") + } + + if filter.MembersOfChannel > 0 { + sql += " AND id IN (SELECT rel_user FROM channel_members WHERE rel_channel = ?)" + params = append(params, filter.MembersOfChannel) + } + } + + sql += " ORDER BY username ASC" + + return rval, r.db().Select(&rval, sql, params...) +} + +func (r *user) CreateUser(mod *types.User) (*types.User, error) { + mod.ID = factory.Sonyflake.NextID() + mod.CreatedAt = time.Now() + mod.Meta = coalesceJson(mod.Meta, []byte("{}")) + return mod, r.db().Insert("users", mod) +} + +func (r *user) UpdateUser(mod *types.User) (*types.User, error) { + mod.UpdatedAt = timeNowPtr() + mod.Meta = coalesceJson(mod.Meta, []byte("{}")) + + return mod, r.db().Replace("users", mod) +} + +func (r *user) SuspendUserByID(id uint64) error { + return r.updateColumnByID("users", "suspend_at", time.Now(), id) +} + +func (r *user) UnsuspendUserByID(id uint64) error { + return r.updateColumnByID("users", "suspend_at", nil, id) +} + +func (r *user) DeleteUserByID(id uint64) error { + return r.updateColumnByID("users", "deleted_at", nil, id) +} diff --git a/auth/repository/util.go b/auth/repository/util.go new file mode 100644 index 000000000..0e80f2d4d --- /dev/null +++ b/auth/repository/util.go @@ -0,0 +1,44 @@ +package repository + +import ( + "encoding/json" + "fmt" + "time" +) + +func (r repository) updateColumnByID(tableName, columnName string, value interface{}, id uint64) (err error) { + return exec(r.db().Exec( + fmt.Sprintf("UPDATE %s SET %s = ? WHERE id = ?", tableName, columnName), + value, + id)) +} + +func exec(_ interface{}, err error) error { + return err +} + +// Returns err if set otherwise it returns nerr if not valid +func isFound(err error, valid bool, nerr error) error { + if err != nil { + return err + } else if !valid { + return nerr + } + + return nil +} + +func timeNowPtr() *time.Time { + n := time.Now() + return &n +} + +func coalesceJson(vals ...json.RawMessage) json.RawMessage { + for _, val := range vals { + if val != nil { + return val + } + } + + return nil +} diff --git a/auth/rest/auth.go b/auth/rest/auth.go new file mode 100644 index 000000000..fa80e11b1 --- /dev/null +++ b/auth/rest/auth.go @@ -0,0 +1,50 @@ +package rest + +import ( + "context" + "github.com/crusttech/crust/auth/rest/request" + "github.com/crusttech/crust/auth/service" + "github.com/crusttech/crust/auth/types" + "github.com/pkg/errors" +) + +var _ = errors.Wrap + +type ( + Auth struct { + user service.UserService + token types.TokenEncoder + } +) + +func (Auth) New(credValidator service.UserService, tknEncoder types.TokenEncoder) *Auth { + return &Auth{ + credValidator, + tknEncoder, + } +} + +func (ctrl *Auth) Login(ctx context.Context, r *request.AuthLogin) (interface{}, error) { + return ctrl.tokenize(ctrl.user.With(ctx).ValidateCredentials(r.Username, r.Password)) +} + +func (ctrl *Auth) Create(ctx context.Context, r *request.AuthCreate) (interface{}, error) { + user := &types.User{Username: r.Username} + user.GeneratePassword(r.Password) + return ctrl.tokenize(ctrl.user.With(ctx).Create(user)) +} + +// Wraps user return value and appends JWT +func (ctrl *Auth) tokenize(user *types.User, err error) (interface{}, error) { + if err != nil { + return nil, err + } + + return struct { + JWT string + User *types.User `json:"user"` + }{ + JWT: ctrl.token.Encode(user), + User: user, + }, nil +} diff --git a/auth/rest/handlers/auth.go b/auth/rest/handlers/auth.go new file mode 100644 index 000000000..956ca98a2 --- /dev/null +++ b/auth/rest/handlers/auth.go @@ -0,0 +1,67 @@ +package handlers + +/* + Hello! This file is auto-generated from `docs/src/spec.json`. + + For development: + In order to update the generated files, edit this file under the location, + add your struct fields, imports, API definitions and whatever you want, and: + + 1. run [spec](https://github.com/titpetric/spec) in the same folder, + 2. run `./_gen.php` in this folder. + + You may edit `auth.go`, `auth.util.go` or `auth_test.go` to + implement your API calls, helper functions and tests. The file `auth.go` + is only generated the first time, and will not be overwritten if it exists. +*/ + +import ( + "context" + "github.com/go-chi/chi" + "net/http" + + "github.com/titpetric/factory/resputil" + + "github.com/crusttech/crust/auth/rest/request" +) + +// Internal API interface +type AuthAPI interface { + Login(context.Context, *request.AuthLogin) (interface{}, error) + Create(context.Context, *request.AuthCreate) (interface{}, error) +} + +// HTTP API interface +type Auth struct { + Login func(http.ResponseWriter, *http.Request) + Create func(http.ResponseWriter, *http.Request) +} + +func NewAuth(ah AuthAPI) *Auth { + return &Auth{ + Login: func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + params := request.NewAuthLogin() + resputil.JSON(w, params.Fill(r), func() (interface{}, error) { + return ah.Login(r.Context(), params) + }) + }, + Create: func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + params := request.NewAuthCreate() + resputil.JSON(w, params.Fill(r), func() (interface{}, error) { + return ah.Create(r.Context(), params) + }) + }, + } +} + +func (ah *Auth) MountRoutes(r chi.Router, middlewares ...func(http.Handler) http.Handler) { + r.Group(func(r chi.Router) { + r.Use(middlewares...) + r.Route("/auth", func(r chi.Router) { + r.Post("/login", ah.Login) + r.Post("/create", ah.Create) + }) + }) +} diff --git a/auth/rest/request/auth.go b/auth/rest/request/auth.go new file mode 100644 index 000000000..f4cb67144 --- /dev/null +++ b/auth/rest/request/auth.go @@ -0,0 +1,139 @@ +package request + +/* + Hello! This file is auto-generated from `docs/src/spec.json`. + + For development: + In order to update the generated files, edit this file under the location, + add your struct fields, imports, API definitions and whatever you want, and: + + 1. run [spec](https://github.com/titpetric/spec) in the same folder, + 2. run `./_gen.php` in this folder. + + You may edit `auth.go`, `auth.util.go` or `auth_test.go` to + implement your API calls, helper functions and tests. The file `auth.go` + is only generated the first time, and will not be overwritten if it exists. +*/ + +import ( + "encoding/json" + "github.com/go-chi/chi" + "github.com/jmoiron/sqlx/types" + "github.com/pkg/errors" + "io" + "net/http" + "strings" +) + +var _ = chi.URLParam +var _ = types.JSONText{} + +// Auth login request parameters +type AuthLogin struct { + Username string + Password string +} + +func NewAuthLogin() *AuthLogin { + return &AuthLogin{} +} + +func (a *AuthLogin) Fill(r *http.Request) error { + var err error + + if strings.ToLower(r.Header.Get("content-type")) == "application/json" { + err = json.NewDecoder(r.Body).Decode(a) + + switch { + case err == io.EOF: + err = nil + case err != nil: + return errors.Wrap(err, "error parsing http request body") + } + } + + r.ParseForm() + get := map[string]string{} + post := map[string]string{} + urlQuery := r.URL.Query() + for name, param := range urlQuery { + get[name] = string(param[0]) + } + postVars := r.Form + for name, param := range postVars { + post[name] = string(param[0]) + } + + if val, ok := post["username"]; ok { + + a.Username = val + } + if val, ok := post["password"]; ok { + + a.Password = val + } + + return err +} + +var _ RequestFiller = NewAuthLogin() + +// Auth create request parameters +type AuthCreate struct { + Name string + Email string + Username string + Password string +} + +func NewAuthCreate() *AuthCreate { + return &AuthCreate{} +} + +func (a *AuthCreate) Fill(r *http.Request) error { + var err error + + if strings.ToLower(r.Header.Get("content-type")) == "application/json" { + err = json.NewDecoder(r.Body).Decode(a) + + switch { + case err == io.EOF: + err = nil + case err != nil: + return errors.Wrap(err, "error parsing http request body") + } + } + + r.ParseForm() + get := map[string]string{} + post := map[string]string{} + urlQuery := r.URL.Query() + for name, param := range urlQuery { + get[name] = string(param[0]) + } + postVars := r.Form + for name, param := range postVars { + post[name] = string(param[0]) + } + + if val, ok := post["name"]; ok { + + a.Name = val + } + if val, ok := post["email"]; ok { + + a.Email = val + } + if val, ok := post["username"]; ok { + + a.Username = val + } + if val, ok := post["password"]; ok { + + a.Password = val + } + + return err +} + +var _ RequestFiller = NewAuthCreate() diff --git a/auth/rest/request/misc.go b/auth/rest/request/misc.go new file mode 100644 index 000000000..9992cfd73 --- /dev/null +++ b/auth/rest/request/misc.go @@ -0,0 +1,10 @@ +package request + +import ( + "net/http" +) + +// RequestFiller is an interface for typed request parameters +type RequestFiller interface { + Fill(r *http.Request) error +} diff --git a/auth/rest/router.go b/auth/rest/router.go new file mode 100644 index 000000000..b11822a4f --- /dev/null +++ b/auth/rest/router.go @@ -0,0 +1,17 @@ +package rest + +import ( + "github.com/crusttech/crust/auth/rest/handlers" + "github.com/crusttech/crust/auth/service" + "github.com/crusttech/crust/auth/types" + "github.com/go-chi/chi" +) + +func MountRoutes(jwtAuth types.TokenEncoder) func(chi.Router) { + var userSvc = service.User() + + // Initialize handers & controllers. + return func(r chi.Router) { + handlers.NewAuth(Auth{}.New(userSvc, jwtAuth)).MountRoutes(r) + } +} diff --git a/auth/routes.go b/auth/routes.go new file mode 100644 index 000000000..7ae92ae51 --- /dev/null +++ b/auth/routes.go @@ -0,0 +1,48 @@ +package auth + +import ( + "fmt" + "reflect" + "runtime" + + "github.com/go-chi/chi" + "github.com/go-chi/chi/middleware" +) + +func mountRoutes(r chi.Router, opts *configuration, mounts ...func(r chi.Router)) { + if opts.http.logging { + r.Use(middleware.Logger) + } + if opts.http.metrics { + r.Use(metrics{}.Middleware("crm")) + } + + for _, mount := range mounts { + mount(r) + } +} + +func mountSystemRoutes(r chi.Router, opts *configuration) { + if opts.http.metrics { + r.Handle("/metrics", metrics{}.Handler()) + } + r.Mount("/debug", middleware.Profiler()) +} + +func printRoutes(r chi.Router, opts *configuration) { + var printRoutes func(chi.Routes, string, string) + printRoutes = func(r chi.Routes, indent string, prefix string) { + routes := r.Routes() + for _, route := range routes { + if route.SubRoutes != nil && len(route.SubRoutes.Routes()) > 0 { + fmt.Printf(indent+"%s - with %d handlers, %d subroutes\n", route.Pattern, len(route.Handlers), len(route.SubRoutes.Routes())) + printRoutes(route.SubRoutes, indent+"\t", prefix+route.Pattern[:len(route.Pattern)-2]) + } else { + for key, fn := range route.Handlers { + fmt.Printf("%s%s\t%s -> %s\n", indent, key, prefix+route.Pattern, runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()) + } + } + } + } + printRoutes(r, "", "") +} diff --git a/auth/service/error.go b/auth/service/error.go new file mode 100644 index 000000000..f2180634b --- /dev/null +++ b/auth/service/error.go @@ -0,0 +1,9 @@ +package service + +type ( + serviceError string +) + +func (e serviceError) Error() string { + return "crust.sam.service." + string(e) +} diff --git a/auth/service/user.go b/auth/service/user.go new file mode 100644 index 000000000..f23aa3b08 --- /dev/null +++ b/auth/service/user.go @@ -0,0 +1,98 @@ +package service + +import ( + "context" + "github.com/crusttech/crust/auth/repository" + "github.com/crusttech/crust/auth/types" +) + +const ( + ErrUserInvalidCredentials = serviceError("UserInvalidCredentials") + ErrUserLocked = serviceError("UserLocked") +) + +type ( + user struct { + repository repository.User + } + + UserService interface { + With(ctx context.Context) UserService + + Create(input *types.User) (*types.User, error) + ValidateCredentials(username, password string) (*types.User, error) + } +) + +func User() UserService { + return &user{ + repository.NewUser(context.Background()), + } +} + +func (svc *user) With(ctx context.Context) UserService { + return &user{ + svc.repository.With(ctx), + } +} + +func (svc *user) ValidateCredentials(username, password string) (*types.User, error) { + user, err := svc.repository.FindUserByUsername(username) + if err != nil { + return nil, err + } + + if !user.ValidatePassword(password) { + return nil, ErrUserInvalidCredentials + } + + if !svc.canLogin(user) { + return nil, ErrUserLocked + } + + return user, nil +} + +func (svc *user) FindByID(id uint64) (*types.User, error) { + return svc.repository.FindUserByID(id) +} + +func (svc *user) Find(filter *types.UserFilter) ([]*types.User, error) { + return svc.repository.FindUsers(filter) +} + +func (svc *user) Create(input *types.User) (user *types.User, err error) { + return user, svc.repository.DB().Transaction(func() error { + // Encrypt user password + if user, err = svc.repository.CreateUser(input); err != nil { + return err + } + return nil + }) +} + +func (svc *user) Update(mod *types.User) (*types.User, error) { + return svc.repository.UpdateUser(mod) +} + +func (svc *user) canLogin(u *types.User) bool { + return u != nil && u.ID > 0 && u.SuspendedAt == nil && u.DeletedAt == nil +} + +func (svc *user) Delete(id uint64) error { + // @todo: permissions check if current user can delete this user + // @todo: notify users that user has been removed (remove from web UI) + return svc.repository.DeleteUserByID(id) +} + +func (svc *user) Suspend(id uint64) error { + // @todo: permissions check if current user can suspend this user + // @todo: notify users that user has been supsended (remove from web UI) + return svc.repository.SuspendUserByID(id) +} + +func (svc *user) Unsuspend(id uint64) error { + // @todo: permissions check if current user can unsuspend this user + // @todo: notify users that user has been unsuspended + return svc.repository.UnsuspendUserByID(id) +} diff --git a/auth/start.go b/auth/start.go new file mode 100644 index 000000000..0b8ea7adb --- /dev/null +++ b/auth/start.go @@ -0,0 +1,95 @@ +package auth + +import ( + "fmt" + "log" + "net" + "net/http" + + "github.com/go-chi/chi" + "github.com/pkg/errors" + + "github.com/SentimensRG/ctx/sigctx" + + "github.com/crusttech/crust/auth/rest" + + "github.com/go-chi/cors" + "github.com/titpetric/factory" + "github.com/titpetric/factory/resputil" +) + +func Init() error { + // validate configuration + if err := config.validate(); err != nil { + return err + } + + // start/configure database connection + factory.Database.Add("default", config.db.dsn) + db, err := factory.Database.Get() + if err != nil { + return err + } + // @todo: profiling as an external service? + switch config.db.profiler { + case "stdout": + db.Profiler = &factory.Database.ProfilerStdout + default: + fmt.Println("No database query profiler selected") + } + + // configure resputil options + resputil.SetConfig(resputil.Options{ + Pretty: config.http.pretty, + Trace: config.http.tracing, + Logger: func(err error) { + // @todo: error logging + }, + }) + + return nil +} + +func Start() error { + var ctx = sigctx.New() + + log.Println("Starting http server on address " + config.http.addr) + listener, err := net.Listen("tcp", config.http.addr) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("Can't listen on addr %s", config.http.addr)) + } + + // JWT Auth + jwtAuth, err := JWT() + if err != nil { + return errors.Wrap(err, "Error creating JWT Auth object") + } + + r := chi.NewRouter() + r.Use(handleCORS) + + // Only protect application routes with JWT + r.Group(func(r chi.Router) { + r.Use(jwtAuth.Verifier(), jwtAuth.Authenticator()) + mountRoutes(r, config, rest.MountRoutes(jwtAuth)) + }) + + printRoutes(r, config) + mountSystemRoutes(r, config) + + go http.Serve(listener, r) + <-ctx.Done() + + return nil +} + +// Sets up default CORS rules to use as a middleware +func handleCORS(next http.Handler) http.Handler { + return cors.New(cors.Options{ + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"}, + AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, + AllowCredentials: true, + MaxAge: 300, // Maximum value not ignored by any of major browsers + }).Handler(next) +} diff --git a/auth/interfaces.go b/auth/types/interfaces.go similarity index 91% rename from auth/interfaces.go rename to auth/types/interfaces.go index ad10152cc..61d51d57e 100644 --- a/auth/interfaces.go +++ b/auth/types/interfaces.go @@ -1,4 +1,4 @@ -package auth +package types type ( Identifiable interface { diff --git a/auth/types/user.go b/auth/types/user.go new file mode 100644 index 000000000..872b11ba7 --- /dev/null +++ b/auth/types/user.go @@ -0,0 +1,48 @@ +package types + +import ( + "encoding/json" + "golang.org/x/crypto/bcrypt" + "time" +) + +type ( + User struct { + ID uint64 `json:"id" db:"id"` + Username string `json:"username" db:"username"` + Meta json.RawMessage `json:"-" db:"meta"` + OrganisationID uint64 `json:"organisationId" db:"rel_organisation"` + Password []byte `json:"-" db:"password"` + CreatedAt time.Time `json:"createdAt,omitempty" db:"created_at"` + UpdatedAt *time.Time `json:"updatedAt,omitempty" db:"updated_at"` + SuspendedAt *time.Time `json:"suspendedAt,omitempty" db:"suspended_at"` + DeletedAt *time.Time `json:"deletedAt,omitempty" db:"deleted_at"` + } + + UserFilter struct { + Query string + MembersOfChannel uint64 + } +) + +func (u *User) Valid() bool { + return u.ID > 0 && u.SuspendedAt == nil && u.DeletedAt == nil +} + +func (u *User) Identity() uint64 { + return u.ID +} + +func (u *User) ValidatePassword(password string) bool { + return bcrypt.CompareHashAndPassword(u.Password, []byte(password)) == nil +} + +func (u *User) GeneratePassword(password string) error { + pwd, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return err + } + + u.Password = pwd + return nil +} diff --git a/cmd/auth/dump.go b/cmd/auth/dump.go new file mode 100644 index 000000000..4c2c161a1 --- /dev/null +++ b/cmd/auth/dump.go @@ -0,0 +1,9 @@ +package main + +// this file exists to keep go-spew in vendor for development needs + +import ( + "github.com/davecgh/go-spew/spew" +) + +var _ = spew.Dump diff --git a/cmd/auth/flags.go b/cmd/auth/flags.go new file mode 100644 index 000000000..b705f48c2 --- /dev/null +++ b/cmd/auth/flags.go @@ -0,0 +1,23 @@ +package main + +import ( + _ "github.com/joho/godotenv/autoload" + "github.com/namsral/flag" +) + +type configuration struct { + monitorInterval int +} + +func flags(prefix string, mountFlags ...func(...string)) configuration { + var config configuration + + flag.IntVar(&config.monitorInterval, "monitor-interval", 300, "Monitor interval (seconds, 0 = disable)") + + for _, mount := range mountFlags { + mount(prefix) + } + + flag.Parse() + return config +} diff --git a/cmd/auth/main.go b/cmd/auth/main.go new file mode 100644 index 000000000..b65f2fecf --- /dev/null +++ b/cmd/auth/main.go @@ -0,0 +1,25 @@ +package main + +import ( + "log" + "os" + + "github.com/crusttech/crust/auth" + "github.com/crusttech/crust/rbac" +) + +func main() { + config := flags("auth", rbac.Flags, auth.Flags) + + // log to stdout not stderr + log.SetOutput(os.Stdout) + log.SetFlags(log.LstdFlags | log.Lshortfile) + go NewMonitor(config.monitorInterval) + + if err := auth.Init(); err != nil { + log.Fatalf("Error initializing auth: %+v", err) + } + if err := auth.Start(); err != nil { + log.Fatalf("Error starting/running auth: %+v", err) + } +} diff --git a/cmd/auth/monitor.go b/cmd/auth/monitor.go new file mode 100644 index 000000000..e0e382862 --- /dev/null +++ b/cmd/auth/monitor.go @@ -0,0 +1,59 @@ +package main + +import ( + "encoding/json" + "expvar" + "fmt" + "runtime" + "time" +) + +type Monitor struct { + Alloc, + TotalAlloc, + Sys, + Mallocs, + Frees, + LiveObjects, + PauseTotalNs uint64 + + NumGC uint32 + NumGoroutine int +} + +func NewMonitor(duration int) { + var ( + m = Monitor{} + rtm runtime.MemStats + goroutines = expvar.NewInt("num_goroutine") + ) + var interval = time.Duration(duration) * time.Second + for { + <-time.After(interval) + + // Read full mem stats + runtime.ReadMemStats(&rtm) + + // Number of goroutines + m.NumGoroutine = runtime.NumGoroutine() + goroutines.Set(int64(m.NumGoroutine)) + + // Misc memory stats + m.Alloc = rtm.Alloc + m.TotalAlloc = rtm.TotalAlloc + m.Sys = rtm.Sys + m.Mallocs = rtm.Mallocs + m.Frees = rtm.Frees + + // Live objects = Mallocs - Frees + m.LiveObjects = m.Mallocs - m.Frees + + // GC Stats + m.PauseTotalNs = rtm.PauseTotalNs + m.NumGC = rtm.NumGC + + // Just encode to json and print + b, _ := json.Marshal(m) + fmt.Println(string(b)) + } +} diff --git a/codegen/auth/rest/handlers/index.php b/codegen/auth/rest/handlers/index.php new file mode 100644 index 000000000..7aa5983f3 --- /dev/null +++ b/codegen/auth/rest/handlers/index.php @@ -0,0 +1,36 @@ + function($name, $api) { + return strtolower($name) . ".go"; + }, +); + +foreach ($templates as $template => $fn) +foreach ($apis as $api) { + $name = ucfirst($api['interface']); + $filename = $dirname . "/" . $fn($name, $api); + + $tpl->load($template); + $tpl->assign($common); + $tpl->assign("package", basename(__DIR__)); + $tpl->assign("name", $name); + $tpl->assign("api", $api); + $tpl->assign("apis", $apis); + $tpl->assign("self", strtolower(substr($name, 0, 1))); + $tpl->assign("structs", $api['struct']); + $imports = array(); + if (is_array($api['struct'])) + foreach ($api['struct'] as $struct) { + if (isset($struct['imports'])) + foreach ($struct['imports'] as $import) { + $imports[] = $import; + } + } + $tpl->assign("imports", $imports); + $tpl->assign("calls", $api['apis']); + $contents = str_replace("\n\n}", "\n}", $tpl->get()); + + file_put_contents($filename, $contents); + echo $filename . "\n"; +} diff --git a/codegen/auth/rest/index.php b/codegen/auth/rest/index.php new file mode 100644 index 000000000..fc5eabcd9 --- /dev/null +++ b/codegen/auth/rest/index.php @@ -0,0 +1,30 @@ +load("http_.tpl"); + $tpl->assign($common); + $tpl->assign("package", basename(__DIR__)); + $tpl->assign("name", $name); + $tpl->assign("api", $api); + $tpl->assign("self", strtolower(substr($name, 0, 1))); + $tpl->assign("structs", $api['struct']); + $imports = array(); + if (is_array($api['struct'])) + foreach ($api['struct'] as $struct) { + if (isset($struct['imports'])) + foreach ($struct['imports'] as $import) { + $imports[] = $import; + } + } + $tpl->assign("imports", $imports); + $tpl->assign("calls", $api['apis']); + $contents = str_replace("\n\n}", "\n}", $tpl->get()); + + if (!file_exists($filename)) { + file_put_contents($filename, $contents); + echo $filename . "\n"; + } +} diff --git a/codegen/auth/rest/request/index.php b/codegen/auth/rest/request/index.php new file mode 100644 index 000000000..c98c0349b --- /dev/null +++ b/codegen/auth/rest/request/index.php @@ -0,0 +1,36 @@ + function($name, $api) { + return strtolower($name) . ".go"; + }, +); + +foreach ($templates as $template => $fn) +foreach ($apis as $api) { + $name = ucfirst($api['interface']); + $filename = $dirname . "/" . $fn($name, $api); + + $tpl->load($template); + $tpl->assign($common); + $tpl->assign("package", basename(__DIR__)); + $tpl->assign("name", $name); + $tpl->assign("api", $api); + $tpl->assign("apis", $apis); + $tpl->assign("self", strtolower(substr($name, 0, 1))); + $tpl->assign("structs", $api['struct']); + $imports = array(); + if (is_array($api['struct'])) + foreach ($api['struct'] as $struct) { + if (isset($struct['imports'])) + foreach ($struct['imports'] as $import) { + $imports[] = $import; + } + } + $tpl->assign("imports", $imports); + $tpl->assign("calls", $api['apis']); + $contents = str_replace("\n\n}", "\n}", $tpl->get()); + + file_put_contents($filename, $contents); + echo $filename . "\n"; +} diff --git a/codegen/auth/types/index.php b/codegen/auth/types/index.php new file mode 100644 index 000000000..0f41ab172 --- /dev/null +++ b/codegen/auth/types/index.php @@ -0,0 +1,37 @@ + function($name, $api) { + return strtolower($name) . ".go"; + }, +); + +foreach ($templates as $template => $fn) +foreach ($apis as $api) { + if (is_array($api['struct'])) { + $name = ucfirst($api['interface']); + $filename = $dirname . "/" . $fn($name, $api); + + $tpl->load($template); + $tpl->assign($common); + $tpl->assign("package", basename(__DIR__)); + $tpl->assign("name", $name); + $tpl->assign("api", $api); + $tpl->assign("apis", $apis); + $tpl->assign("self", strtolower(substr($name, 0, 1))); + $tpl->assign("structs", $api['struct']); + $imports = array(); + foreach ($api['struct'] as $struct) { + if (isset($struct['imports'])) + foreach ($struct['imports'] as $import) { + $imports[] = $import; + } + } + $tpl->assign("imports", $imports); + $tpl->assign("calls", $api['apis']); + $contents = str_replace("\n\n}", "\n}", $tpl->get()); + + file_put_contents($filename, $contents); + echo $filename . "\n"; + } +}