3
0

POC OIDC implementation

This commit is contained in:
Denis Arh 2018-09-12 03:34:26 +02:00
parent 6e039e1782
commit cdd72c8b90
14 changed files with 372 additions and 127 deletions

View File

@ -10,6 +10,8 @@ type (
appFlags struct {
http *config.HTTP
db *config.Database
jwt *config.JWT
oidc *config.OIDC
}
)
@ -25,6 +27,12 @@ func (c *appFlags) Validate() error {
if err := c.db.Validate(); err != nil {
return err
}
if err := c.jwt.Validate(); err != nil {
return err
}
if err := c.oidc.Validate(); err != nil {
return err
}
return nil
}
@ -38,5 +46,7 @@ func Flags(prefix ...string) {
flags = &appFlags{
new(config.HTTP).Init(prefix...),
new(config.Database).Init(prefix...),
new(config.JWT).Init(prefix...),
new(config.OIDC).Init(prefix...),
}
}

View File

@ -17,6 +17,7 @@ type (
With(context.Context) User
FindUserByEmail(email string) (*types.User, error)
FindUserByUsername(username string) (*types.User, error)
FindUserByID(id uint64) (*types.User, error)
FindUsers(filter *types.UserFilter) ([]*types.User, error)
@ -52,6 +53,13 @@ func (r *user) FindUserByUsername(username string) (*types.User, error) {
return mod, isFound(r.db().Get(mod, sql, username), mod.ID > 0, ErrUserNotFound)
}
func (r *user) FindUserByEmail(email string) (*types.User, error) {
sql := "SELECT * FROM users WHERE email = ? AND " + sqlUserScope
mod := &types.User{}
return mod, isFound(r.db().Get(mod, sql, email), mod.ID > 0, ErrUserNotFound)
}
func (r *user) FindUserByID(id uint64) (*types.User, error) {
sql := "SELECT * FROM users WHERE id = ? AND " + sqlUserScope
mod := &types.User{}

172
auth/rest/oidc.go Normal file
View File

@ -0,0 +1,172 @@
package rest
import (
"context"
"github.com/coreos/go-oidc"
"github.com/crusttech/crust/auth/service"
"github.com/crusttech/crust/auth/types"
"github.com/crusttech/crust/config"
"github.com/titpetric/factory/resputil"
"golang.org/x/oauth2"
"math/rand"
"net/http"
"strconv"
"time"
)
type (
openIdConnect struct {
provider *oidc.Provider
verifier *oidc.IDTokenVerifier
config oauth2.Config
appURL string
stateCookieExpiry int64
userService service.UserService
jwt jwtEncodeCookieSetter
}
jwtEncodeCookieSetter interface {
types.TokenEncoder
SetToCookie(w http.ResponseWriter, r *http.Request, identity types.Identifiable)
}
)
const openIdConnectStateCookie = "oidc-state"
func OpenIdConnect(cfg *config.OIDC, usvc service.UserService, jwt jwtEncodeCookieSetter) (c *openIdConnect, err error) {
c = &openIdConnect{
appURL: cfg.AppURL,
stateCookieExpiry: cfg.StateCookieExpiry,
userService: usvc,
jwt: jwt,
}
// Allow 5 seconds for issuer discovery process
c.provider, err = oidc.NewProvider(context.Background(), cfg.Issuer)
if err != nil {
return nil, err
}
// Configure an OpenID Connect aware OAuth2 client.
c.config = oauth2.Config{
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
RedirectURL: cfg.RedirectURL,
// Discovery returns the OAuth2 endpoints.
Endpoint: c.provider.Endpoint(),
// "openid" is a required scope for OpenID Connect flows.
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
c.verifier = c.provider.Verifier(&oidc.Config{ClientID: cfg.ClientID})
return
}
// Redirects user to the issuer's login screen
func (c *openIdConnect) HandleRedirect(w http.ResponseWriter, r *http.Request) {
// @todo sure we can improve this...
rand.Seed(4321)
var state = strconv.FormatInt(rand.Int63(), 10)
// Store state to cookie as well
c.setStateCookie(w, r, state)
http.Redirect(w, r, c.config.AuthCodeURL(state), http.StatusFound)
}
// Handles callback from issuer
//
// If everything goes well (scope & token verification) it reads issued claims,
// creates Crust JWT and stores it in a cookie.
//
// @todo All failed responses must redirect to appURL as well + some error code that will be displayed on the client
func (c *openIdConnect) HandleOAuth2Callback(w http.ResponseWriter, r *http.Request) {
var ctx = r.Context()
if !c.stateCheck(r) {
resputil.JSON(w, "State check failed")
return
}
c.setStateCookie(w, r, "") // remove state cookie
oauth2Token, err := c.config.Exchange(ctx, r.URL.Query().Get("code"))
if err != nil {
resputil.JSON(w, err)
return
}
// Extract the ID Token from OAuth2 token.
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
resputil.JSON(w, err)
return
}
// Parse and verify ID Token payload.
idToken, err := c.verifier.Verify(ctx, rawIDToken)
if err != nil {
resputil.JSON(w, err)
return
}
// Extract custom claims
var claims struct {
Email string `json:"email"`
Verified bool `json:"email_verified"`
}
if err := idToken.Claims(&claims); err != nil {
resputil.JSON(w, err)
return
}
var user *types.User
if user, err = c.userService.FindOrCreate(claims.Email); err != nil {
resputil.JSON(w, err)
return
} else {
c.jwt.SetToCookie(w, r, user)
}
http.Redirect(w, r, c.appURL+"?jwt="+c.jwt.Encode(user), http.StatusSeeOther)
}
func (c *openIdConnect) stateCheck(r *http.Request) bool {
if cState, err := r.Cookie(openIdConnectStateCookie); err == nil {
rState := r.URL.Query().Get("state")
return len(rState) > 0 && cState.Value == rState
}
return false
}
// Sets state cookie
func (c *openIdConnect) setStateCookie(w http.ResponseWriter, r *http.Request, value string) {
var maxAge int
if len(value) == 0 {
// When empty string for a value is received,
// set maxAge to -1. That will effectively delete the cookie
maxAge = -1
}
// Store state to cookie as well
http.SetCookie(w, &http.Cookie{
Name: openIdConnectStateCookie,
Value: value,
Expires: time.Now().Add(time.Duration(c.stateCookieExpiry) * time.Minute),
MaxAge: maxAge,
HttpOnly: true,
Secure: r.URL.Scheme == "https",
Path: "/oidc",
})
}

View File

@ -1,18 +1,42 @@
package rest
import (
"log"
"net/http"
"github.com/go-chi/chi"
"github.com/titpetric/factory/resputil"
"github.com/crusttech/crust/auth/rest/handlers"
"github.com/crusttech/crust/auth/service"
"github.com/crusttech/crust/internal/auth"
"github.com/crusttech/crust/config"
)
func MountRoutes(jwtAuth auth.TokenEncoder) func(chi.Router) {
func MountRoutes(oidcConfig *config.OIDC, jwtAuth jwtEncodeCookieSetter) func(chi.Router) {
var userSvc = service.User()
oidc, err := OpenIdConnect(oidcConfig, userSvc, jwtAuth)
if err != nil {
log.Errorf("Could not initialize OIDC:", err.Error())
}
// Initialize handers & controllers.
return func(r chi.Router) {
handlers.NewAuth(Auth{}.New(userSvc, jwtAuth)).MountRoutes(r)
if oidc != nil {
r.Route("/oidc", func(r chi.Router) {
r.Get("/", oidc.HandleRedirect)
r.Get("/callback", oidc.HandleOAuth2Callback)
})
}
r.Get("/jwt", func(w http.ResponseWriter, r *http.Request) {
if c, err := r.Cookie("jwt"); err != nil {
resputil.JSON(w, "")
} else {
resputil.JSON(w, c.Value)
}
})
}
}

View File

@ -25,10 +25,8 @@ type (
Create(input *types.User) (*types.User, error)
Update(mod *types.User) (*types.User, error)
Delete(id uint64) error
Suspend(id uint64) error
Unsuspend(id uint64) error
FindOrCreate(email string) (*types.User, error)
ValidateCredentials(username, password string) (*types.User, error)
}
)
@ -70,6 +68,19 @@ func (svc *user) Find(filter *types.UserFilter) ([]*types.User, error) {
return svc.repository.FindUsers(filter)
}
// Finds if user with a specific email exists and returns it otherwise it creates a fresh one
func (svc *user) FindOrCreate(email string) (user *types.User, err error) {
return user, svc.repository.DB().Transaction(func() error {
if user, err = svc.repository.FindUserByEmail(email); err != repository.ErrUserNotFound {
return err
} else if user, err = svc.repository.CreateUser(&types.User{Email: email}); err != nil {
return err
}
return nil
})
}
func (svc *user) Create(input *types.User) (user *types.User, err error) {
return user, svc.repository.DB().Transaction(func() error {
// Encrypt user password
@ -87,21 +98,3 @@ func (svc *user) Update(mod *types.User) (*types.User, error) {
func (svc *user) canLogin(u *types.User) bool {
return u != nil && u.ID > 0 && u.SuspendedAt == nil && u.DeletedAt == nil
}
func (svc *user) Delete(id uint64) error {
// @todo: permissions check if current user can delete this user
// @todo: notify users that user has been removed (remove from web UI)
return svc.repository.DeleteUserByID(id)
}
func (svc *user) Suspend(id uint64) error {
// @todo: permissions check if current user can suspend this user
// @todo: notify users that user has been supsended (remove from web UI)
return svc.repository.SuspendUserByID(id)
}
func (svc *user) Unsuspend(id uint64) error {
// @todo: permissions check if current user can unsuspend this user
// @todo: notify users that user has been unsuspended
return svc.repository.UnsuspendUserByID(id)
}

View File

@ -7,13 +7,13 @@ import (
"net/http"
"github.com/SentimensRG/ctx/sigctx"
"github.com/crusttech/crust/auth/rest"
"github.com/go-chi/chi"
"github.com/go-chi/cors"
"github.com/pkg/errors"
"github.com/titpetric/factory"
"github.com/titpetric/factory/resputil"
"github.com/crusttech/crust/auth/rest"
"github.com/crusttech/crust/internal/auth"
)
@ -70,7 +70,7 @@ func Start() error {
// Only protect application routes with JWT
r.Group(func(r chi.Router) {
r.Use(jwtAuth.Verifier(), jwtAuth.Authenticator())
mountRoutes(r, flags.http, rest.MountRoutes(jwtAuth))
mountRoutes(r, flags.http, rest.MountRoutes(flags.oidc, jwtAuth))
})
printRoutes(r, flags.http)

View File

@ -10,6 +10,7 @@ type (
User struct {
ID uint64 `json:"id" db:"id"`
Username string `json:"username" db:"username"`
Email string `json:"email" db:"email"`
Meta json.RawMessage `json:"-" db:"meta"`
OrganisationID uint64 `json:"organisationId" db:"rel_organisation"`
Password []byte `json:"-" db:"password"`

60
config/oidc.go Normal file
View File

@ -0,0 +1,60 @@
package config
import (
"github.com/namsral/flag"
"github.com/pkg/errors"
)
type (
OIDC struct {
Issuer string
ClientID string
ClientSecret string
RedirectURL string
AppURL string
StateCookieExpiry int64
}
)
var oidc *OIDC
func (c *OIDC) Validate() error {
if c == nil {
return nil
}
if c.Issuer == "" {
return errors.New("OIDC Issuer not set for AUTH")
}
if c.ClientID == "" {
return errors.New("OIDC ClientID not set for AUTH")
}
if c.ClientSecret == "" {
return errors.New("OIDC ClientSecret not set for AUTH")
}
if c.RedirectURL == "" {
return errors.New("OIDC RedirectURL not set for AUTH")
}
if c.AppURL == "" {
return errors.New("OIDC AppURL not set for AUTH")
}
return nil
}
func (*OIDC) Init(prefix ...string) *OIDC {
if oidc != nil {
return oidc
}
oidc := new(OIDC)
flag.StringVar(&oidc.Issuer, "auth-oidc-issuer", "", "OIDC Issuer")
flag.StringVar(&oidc.ClientID, "auth-oidc-client-id", "", "OIDC Client ID")
flag.StringVar(&oidc.ClientSecret, "auth-oidc-client-secret", "", "OIDC Client Secret")
flag.StringVar(&oidc.RedirectURL, "auth-oidc-redirect-url", "", "OIDC RedirectURL")
flag.StringVar(&oidc.AppURL, "auth-oidc-app-url", "", "OIDC AppURL")
flag.Int64Var(&oidc.StateCookieExpiry, "auth-oidc-state-cookie-expiry", 15, "OIDC State cookie expiry in minutes")
return oidc
}

View File

@ -11,7 +11,9 @@ import (
)
type jwt struct {
tokenAuth *jwtauth.JWTAuth
expiry int64
cookieDomain string
tokenAuth *jwtauth.JWTAuth
}
func JWT() (*jwt, error) {
@ -19,7 +21,11 @@ func JWT() (*jwt, error) {
return nil, err
}
jwt := &jwt{tokenAuth: jwtauth.New("HS256", []byte(flags.jwt.Secret), nil)}
jwt := &jwt{
expiry: flags.jwt.Expiry,
cookieDomain: flags.jwt.CookieDomain,
tokenAuth: jwtauth.New("HS256", []byte(flags.jwt.Secret), nil),
}
if flags.jwt.DebugToken {
log.Println("DEBUG JWT TOKEN:", jwt.Encode(NewIdentity(1)))
@ -34,10 +40,9 @@ func (t *jwt) Verifier() func(http.Handler) http.Handler {
}
func (t *jwt) Encode(identity Identifiable) string {
// @todo Set expiry
claims := jwtauth.Claims{}
claims.Set("sub", strconv.FormatUint(identity.Identity(), 10))
claims.SetExpiryIn(time.Duration(flags.jwt.Expiry) * time.Minute)
claims.SetExpiryIn(time.Duration(t.expiry) * time.Minute)
_, jwt, _ := t.tokenAuth.Encode(claims)
return jwt
@ -62,3 +67,18 @@ func (t *jwt) Authenticator() func(http.Handler) http.Handler {
})
}
}
// Extracts and authenticates JWT from context, validates claims
func (t *jwt) SetToCookie(w http.ResponseWriter, r *http.Request, identity Identifiable) {
// Store state to cookie as well
http.SetCookie(w, &http.Cookie{
Name: "jwt",
Value: t.Encode(identity),
Expires: time.Now().Add(time.Duration(t.expiry) * time.Minute),
Domain: t.cookieDomain,
Secure: r.URL.Scheme == "https",
Path: "/",
})
}

View File

@ -7,9 +7,10 @@ import (
type (
JWT struct {
Secret string
Expiry int64
DebugToken bool
Secret string
Expiry int64
DebugToken bool
CookieDomain string
}
)
@ -33,6 +34,7 @@ func (*JWT) Init(prefix ...string) *JWT {
jwt := new(JWT)
flag.StringVar(&jwt.Secret, "auth-jwt-secret", "", "JWT Secret")
flag.Int64Var(&jwt.Expiry, "auth-jwt-expiry", 3600, "JWT Expiration in minutes")
flag.StringVar(&jwt.CookieDomain, "auth-jwt-cookie-domain", "", "JWT Cookie domain")
flag.BoolVar(&jwt.DebugToken, "auth-jwt-debug", false, "Generate debug JWT key")
return jwt
}

View File

@ -52,6 +52,7 @@ CREATE TABLE channels (
-- changes are stored in audit log
CREATE TABLE users (
id BIGINT UNSIGNED NOT NULL,
email TEXT NOT NULL,
username TEXT NOT NULL,
password TEXT,
meta JSON NOT NULL,

View File

@ -1,85 +0,0 @@
package sam
import (
"context"
"github.com/coreos/go-oidc"
"github.com/davecgh/go-spew/spew"
"golang.org/x/oauth2"
"net/http"
)
type (
openIdConnect struct {
provider *oidc.Provider
verifier *oidc.IDTokenVerifier
config oauth2.Config
}
)
func OpenIdConnect(ctx context.Context, issuer string, cfg oauth2.Config) (c *openIdConnect, err error) {
c = &openIdConnect{}
c.provider, err = oidc.NewProvider(ctx, issuer)
if err != nil {
return nil, err
}
// Configure an OpenID Connect aware OAuth2 client.
c.config = oauth2.Config{
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
RedirectURL: cfg.RedirectURL,
// Discovery returns the OAuth2 endpoints.
Endpoint: c.provider.Endpoint(),
// "openid" is a required scope for OpenID Connect flows.
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
c.verifier = c.provider.Verifier(&oidc.Config{ClientID: cfg.ClientID})
return
}
func (c *openIdConnect) HandleRedirect(w http.ResponseWriter, r *http.Request) {
state := "@todo"
http.Redirect(w, r, c.config.AuthCodeURL(state), http.StatusFound)
}
func (c *openIdConnect) HandleOAuth2Callback(w http.ResponseWriter, r *http.Request) {
var ctx = r.Context()
// @todo check state
// Verify state and errors.
oauth2Token, err := c.config.Exchange(ctx, r.URL.Query().Get("code"))
if err != nil {
// handle error
}
// Extract the ID Token from OAuth2 token.
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
// handle missing token
}
// Parse and verify ID Token payload.
idToken, err := c.verifier.Verify(ctx, rawIDToken)
if err != nil {
// handle error
}
// Extract custom claims
var claims struct {
Email string `json:"email"`
Verified bool `json:"email_verified"`
}
if err := idToken.Claims(&claims); err != nil {
// handle error
}
spew.Dump()
}

View File

@ -13,7 +13,6 @@ import (
"github.com/pkg/errors"
"github.com/titpetric/factory"
"github.com/titpetric/factory/resputil"
"golang.org/x/oauth2"
authService "github.com/crusttech/crust/auth/service"
"github.com/crusttech/crust/internal/auth"
@ -75,15 +74,6 @@ func Start() error {
r := chi.NewRouter()
r.Use(handleCORS)
if oidc, err := OpenIdConnect(ctx, "https://accounts.google.com", oauth2.Config{}); err != nil {
return errors.Wrap(err, "Could not initialize OIDC")
} else {
r.Route("/oidc/satosa", func(r chi.Router) {
r.Get("/", oidc.HandleRedirect)
r.Get("/callback", oidc.HandleOAuth2Callback)
})
}
// Only protect application routes with JWT
r.Group(func(r chi.Router) {
r.Use(jwtAuth.Verifier(), jwtAuth.Authenticator())

49
sam/types/user.go Normal file
View File

@ -0,0 +1,49 @@
package types
import (
"encoding/json"
"golang.org/x/crypto/bcrypt"
"time"
)
type (
User struct {
ID uint64 `json:"id" db:"id"`
Email string `json:"email" db:"email"`
Username string `json:"username" db:"username"`
Meta json.RawMessage `json:"-" db:"meta"`
OrganisationID uint64 `json:"organisationId" db:"rel_organisation"`
Password []byte `json:"-" db:"password"`
CreatedAt time.Time `json:"createdAt,omitempty" db:"created_at"`
UpdatedAt *time.Time `json:"updatedAt,omitempty" db:"updated_at"`
SuspendedAt *time.Time `json:"suspendedAt,omitempty" db:"suspended_at"`
DeletedAt *time.Time `json:"deletedAt,omitempty" db:"deleted_at"`
}
UserFilter struct {
Query string
MembersOfChannel uint64
}
)
func (u *User) Valid() bool {
return u.ID > 0 && u.SuspendedAt == nil && u.DeletedAt == nil
}
func (u *User) Identity() uint64 {
return u.ID
}
func (u *User) ValidatePassword(password string) bool {
return bcrypt.CompareHashAndPassword(u.Password, []byte(password)) == nil
}
func (u *User) GeneratePassword(password string) error {
pwd, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
u.Password = pwd
return nil
}