3
0

Replacing dgrijalva/jwt-go with lestrrat-go/jwx

This commit is contained in:
Denis Arh
2022-01-05 08:45:32 +01:00
parent 60020f8510
commit 72999ca692
20 changed files with 310 additions and 296 deletions

View File

@@ -91,7 +91,7 @@ func Command(ctx context.Context, app serviceInitializer, storeInit func(ctx con
Run: func(cmd *cobra.Command, args []string) {
ctx = auth.SetIdentityToContext(ctx, auth.ServiceUser())
var (
at string
at []byte
user *types.User
err error
@@ -104,9 +104,9 @@ func Command(ctx context.Context, app serviceInitializer, storeInit func(ctx con
err = service.DefaultAuth.LoadRoleMemberships(ctx, user)
cli.HandleError(err)
at, err = auth.DefaultJwtHandler.Generate(ctx, user)
at, err = auth.DefaultJwtHandler.Generate(ctx, user, 0)
cli.HandleError(err)
cmd.Println(at)
cmd.Println(string(at))
},
}

View File

@@ -4,94 +4,55 @@ import (
"context"
"strings"
"github.com/cortezaproject/corteza-server/pkg/auth"
"github.com/cortezaproject/corteza-server/pkg/payload"
"github.com/cortezaproject/corteza-server/pkg/rand"
"github.com/dgrijalva/jwt-go"
"github.com/go-oauth2/oauth2/v4"
"github.com/go-oauth2/oauth2/v4/errors"
"github.com/spf13/cast"
)
// JWTAccessGenerate generate the jwt access token
type (
JWTAccessGenerate struct {
tm auth.TokenGenerator
}
)
// NewJWTAccessGenerate create to generate the jwt access token instance
//
// @todo move this to pkg/auth (??) so it can be re-used
func NewJWTAccessGenerate(kid string, key []byte, method jwt.SigningMethod) *JWTAccessGenerate {
return &JWTAccessGenerate{
SignedKeyID: kid,
SignedKey: key,
SignedMethod: method,
}
}
// JWTAccessGenerate generate the jwt access token
type JWTAccessGenerate struct {
SignedKeyID string
SignedKey []byte
SignedMethod jwt.SigningMethod
func NewJWTAccessGenerate(tg auth.TokenGenerator) *JWTAccessGenerate {
return &JWTAccessGenerate{tg}
}
// Token based on the UUID generated token
func (a *JWTAccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) {
// extract user ID and roles from a space-delimited list of IDs stored in userID
userIdWithRoles := strings.SplitN(data.TokenInfo.GetUserID(), " ", 2)
if len(userIdWithRoles) == 1 {
userIdWithRoles = append(userIdWithRoles, "")
}
func (a *JWTAccessGenerate) Token(_ context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (_ string, refresh string, err error) {
var (
user auth.Identifiable
rawToken []byte
)
// using jwt.MapClaims is good enough, it's validation rules ae
claims := jwt.MapClaims{
"aud": data.Client.GetID(),
"sub": userIdWithRoles[0],
"exp": data.TokenInfo.GetAccessCreateAt().Add(data.TokenInfo.GetAccessExpiresIn()).Unix(),
"scope": data.TokenInfo.GetScope(),
"roles": userIdWithRoles[1],
}
token := jwt.NewWithClaims(a.SignedMethod, claims)
token.Header["salt"] = string(rand.Bytes(32))
if a.SignedKeyID != "" {
token.Header["kid"] = a.SignedKeyID
}
var key interface{}
if a.isEs() {
v, err := jwt.ParseECPrivateKeyFromPEM(a.SignedKey)
if err != nil {
return "", "", err
{
// extract user ID and roles from a space-delimited list of IDs stored in userID
userIdWithRoles := strings.Split(data.TokenInfo.GetUserID(), " ")
if len(userIdWithRoles) == 1 {
user = auth.Authenticated(cast.ToUint64(userIdWithRoles[0]))
} else {
user = auth.Authenticated(
cast.ToUint64(userIdWithRoles[0]),
payload.ParseUint64s(userIdWithRoles)...,
)
}
key = v
} else if a.isRsOrPS() {
v, err := jwt.ParseRSAPrivateKeyFromPEM(a.SignedKey)
if err != nil {
return "", "", err
}
key = v
} else if a.isHs() {
key = a.SignedKey
} else {
return "", "", errors.New("unsupported sign method")
}
access, err := token.SignedString(key)
rawToken, err = a.tm.Encode(user, cast.ToUint64(data.Client.GetID()), data.TokenInfo.GetScope())
if err != nil {
return "", "", err
return
}
refresh := ""
if isGenRefresh {
refresh = string(rand.Bytes(48))
}
return access, refresh, nil
}
func (a *JWTAccessGenerate) isEs() bool {
return strings.HasPrefix(a.SignedMethod.Alg(), "ES")
}
func (a *JWTAccessGenerate) isRsOrPS() bool {
isRs := strings.HasPrefix(a.SignedMethod.Alg(), "RS")
isPs := strings.HasPrefix(a.SignedMethod.Alg(), "PS")
return isRs || isPs
}
func (a *JWTAccessGenerate) isHs() bool {
return strings.HasPrefix(a.SignedMethod.Alg(), "HS")
return string(rawToken), refresh, nil
}

View File

@@ -3,9 +3,9 @@ package oauth2
import (
"strings"
"github.com/cortezaproject/corteza-server/pkg/auth"
"github.com/cortezaproject/corteza-server/pkg/logger"
"github.com/cortezaproject/corteza-server/pkg/options"
"github.com/dgrijalva/jwt-go"
"github.com/go-oauth2/oauth2/v4"
"github.com/go-oauth2/oauth2/v4/errors"
"github.com/go-oauth2/oauth2/v4/manage"
@@ -32,7 +32,7 @@ func NewManager(opt options.AuthOpt, log *zap.Logger, cs oauth2.ClientStore, ts
manager.MapTokenStorage(ts)
// generate jwt access token
manager.MapAccessGenerate(NewJWTAccessGenerate("", []byte(opt.Secret), jwt.SigningMethodHS512))
manager.MapAccessGenerate(NewJWTAccessGenerate(auth.DefaultJwtHandler))
manager.MapClientStorage(cs)
manager.SetValidateURIHandler(func(baseURI, redirectURI string) (err error) {

View File

@@ -2,13 +2,13 @@ package automation
import (
"context"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"strings"
"github.com/dgrijalva/jwt-go"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
)
type (
@@ -28,8 +28,6 @@ func JwtHandler(reg jwtHandlerRegistry) *jwtHandler {
func (h jwtHandler) generate(ctx context.Context, args *jwtGenerateArgs) (res *jwtGenerateResults, err error) {
var (
secret interface{}
auxp = make(map[string]interface{})
auxh = make(map[string]interface{})
)
@@ -69,28 +67,40 @@ func (h jwtHandler) generate(ctx context.Context, args *jwtGenerateArgs) (res *j
return r == ' ' || r == ','
})
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims(auxp))
var (
tkn = jwt.New()
keySet jwk.Set
tokenBytes []byte
)
// merge header with user-provided header
for k, v := range auxh {
token.Header[k] = v
}
// check if we use cert
{
pemBlock, _ := pem.Decode([]byte(args.secretString))
if pemBlock != nil {
if secret, err = x509.ParsePKCS8PrivateKey(pemBlock.Bytes); err != nil {
return
}
} else {
secret = []byte(args.secretString)
for k, v := range auxp {
if err = tkn.Set(k, v); err != nil {
return
}
}
res = &jwtGenerateResults{}
res.Token, err = token.SignedString(secret)
//< HEAD
// // check if we use cert
// {
// pemBlock, _ := pem.Decode([]byte(args.secretString))
//
// if pemBlock != nil {
// if secret, err = x509.ParsePKCS8PrivateKey(pemBlock.Bytes); err != nil {
// return
// }
// } else {
// secret = []byte(args.secretString)
// }
//=
// @todo check if jwk.Parse provides the same logic as before with pem.Decode and x59.ParsePkC8PrivateKey
if keySet, err = jwk.Parse([]byte(args.secretString)); err != nil {
return
//> e3a304d5... Replacing dgrijalva/jwt-go with lestrrat-go/jwx
}
return
if tokenBytes, err = jwt.Sign(tkn, jwa.HS512, keySet); err != nil {
return
}
return &jwtGenerateResults{Token: string(tokenBytes)}, nil
}

View File

@@ -290,9 +290,9 @@ func (svc node) Pair(ctx context.Context, nodeID uint64) error {
return err
}
var accessToken string
var accessToken []byte
// Generate JWT token for the federated user
accessToken, err = svc.tokenEncoder.Generate(ctx, u)
accessToken, err = svc.tokenEncoder.Generate(ctx, u, 0)
if err != nil {
return err
}
@@ -301,7 +301,7 @@ func (svc node) Pair(ctx context.Context, nodeID uint64) error {
n.UpdatedAt = now()
// Start handshake initialization remote node
if err = svc.handshaker.Init(ctx, n, accessToken); err != nil {
if err = svc.handshaker.Init(ctx, n, string(accessToken)); err != nil {
return err
}
@@ -361,15 +361,15 @@ func (svc node) HandshakeConfirm(ctx context.Context, nodeID uint64) error {
return err
}
var accessToken string
var accessToken []byte
// Generate JWT token for the federated user
accessToken, err = svc.tokenEncoder.Generate(ctx, u)
accessToken, err = svc.tokenEncoder.Generate(ctx, u, 0)
n.UpdatedBy = auth.GetIdentityFromContext(ctx).Identity()
n.UpdatedAt = now()
// Complete handshake on remote node
if err = svc.handshaker.Complete(ctx, n, accessToken); err != nil {
if err = svc.handshaker.Complete(ctx, n, string(accessToken)); err != nil {
return err
}

View File

@@ -3,8 +3,6 @@ package auth
import (
"context"
"net/http"
"github.com/dgrijalva/jwt-go"
)
type (
@@ -16,12 +14,12 @@ type (
}
TokenGenerator interface {
Generate(ctx context.Context, identity Identifiable) (string, error)
Encode(i Identifiable, clientID uint64, scope ...string) (token []byte, err error)
Generate(ctx context.Context, i Identifiable, clientID uint64, scope ...string) (token []byte, err error)
}
TokenHandler interface {
TokenGenerator
Authenticate(token string) (jwt.MapClaims, error)
HttpVerifier() func(http.Handler) http.Handler
HttpAuthenticator() func(http.Handler) http.Handler
}

View File

@@ -4,27 +4,25 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/cortezaproject/corteza-server/pkg/id"
"github.com/cortezaproject/corteza-server/pkg/rand"
"github.com/cortezaproject/corteza-server/pkg/payload"
"github.com/cortezaproject/corteza-server/system/types"
"github.com/cortezaproject/corteza-server/pkg/api"
"github.com/dgrijalva/jwt-go"
"github.com/go-chi/jwtauth"
"github.com/pkg/errors"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
"github.com/spf13/cast"
)
type (
token struct {
tokenManager struct {
// Expiration time in minutes
expiry time.Duration
tokenAuth *jwtauth.JWTAuth
secret []byte
expiry time.Duration
signAlgo jwa.SignatureAlgorithm
signKey jwk.Set
}
tokenStore interface {
@@ -52,21 +50,50 @@ var (
func SetupDefault(secret string, expiry time.Duration) {
// Use JWT secret for hmac signer for now
DefaultSigner = HmacSigner(secret)
DefaultJwtHandler, _ = JWT(secret, expiry)
DefaultJwtHandler, _ = TokenManager(secret, expiry)
}
func JWT(secret string, expiry time.Duration) (tkn *token, err error) {
func TokenManager(secret string, expiry time.Duration) (*tokenManager, error) {
var (
err error
set jwk.Set
)
if len(secret) == 0 {
return nil, errors.New("JWT secret missing")
return nil, fmt.Errorf("JWT secret missing")
}
tkn = &token{
expiry: expiry,
tokenAuth: jwtauth.New(jwt.SigningMethodHS512.Alg(), []byte(secret), nil),
secret: []byte(secret),
// @todo jwk.Parse can accept other input types beside byte-slice
// we could use it to strength Corteza's security
if set, err = jwk.Parse([]byte(secret)); err != nil {
return nil, err
}
return tkn, nil
return &tokenManager{
expiry: expiry,
signAlgo: jwa.HS512,
signKey: set,
}, nil
//
//var (
// // tuukn = jwt.New()
// // signed []byte
// //)
// //
// //if err = tuukn.Set(jwt.ExpirationKey, expiry); err != nil {
// // return
// //}
// //
// signed, err = jwt.Sign(tuukn, jwa.HS512, []byte(secret))
//
// tkn = &tokenManager{
// expiry: expiry,
// tokenAuth: jwtauth.New(jwt.SigningMethodHS512.Alg(), []byte(secret), nil),
// secret: []byte(secret),
// }, nil
//
//return tkn, nil
}
// SetJWTStore set store for JWT
@@ -76,36 +103,43 @@ func SetJWTStore(store tokenStore) {
DefaultJwtStore = store
}
func (t *token) Authenticate(token string) (jwt.MapClaims, error) {
dt, err := t.tokenAuth.Decode(token)
if err != nil {
return nil, err
// Authenticate the tokej from the given string and return parsed token or error
func (tm *tokenManager) Authenticate(token string) (pToken jwt.Token, err error) {
if pToken, err = jwt.Parse([]byte(token), jwt.WithKeySet(tm.signKey)); err != nil {
return
}
if dt == nil || !dt.Valid {
return nil, jwtauth.ErrUnauthorized
if err = jwt.Validate(pToken); err != nil {
return
}
if dt.Method != jwt.SigningMethodHS512 {
return nil, jwtauth.ErrAlgoInvalid
}
if mc, is := dt.Claims.(jwt.MapClaims); is {
return mc, nil
}
return nil, nil
return
}
// HttpVerifier returns a HTTP handler that verifies JWT and stores it into context
func (t *token) HttpVerifier() func(http.Handler) http.Handler {
return jwtauth.Verifier(t.tokenAuth)
}
//// Encode identity into a
//func (tm *tokenManager) Encode(identity Identifiable, scope ...string) ([]byte, error) {
// var (
// // when possible, extend this with the client
// clientID uint64 = 0
// )
//
// if len(scope) == 0 {
// // for backward compatibility we default
// // unset scope to profile & api
// scope = []string{"profile", "api"}
// }
//
// return tm.Encode(identity, clientID, scope...)
//}
func (t *token) Encode(i Identifiable, scope ...string) string {
// Encode give identity, clientID & scope into JWT access token (that can be use for API requests)
//
// @todo this follows implementation in auth/oauth2/jwt_access.go
// and should be refactored accordingly (move both into the same location/pkg => here)
func (tm *tokenManager) Encode(identity Identifiable, clientID uint64, scope ...string) (_ []byte, err error) {
var (
// when possible, extend this with the client
clientID uint64 = 0
token = jwt.New()
roles = ""
)
if len(scope) == 0 {
@@ -114,62 +148,98 @@ func (t *token) Encode(i Identifiable, scope ...string) string {
scope = []string{"profile", "api"}
}
return t.encode(i, clientID, scope...)
}
// encode give identity, clientID & scope into JWT access token (that can be use for API requests)
//
// @todo this follows implementation in auth/oauth2/jwt_access.go
// and should be refactored accordingly (move both into the same location/pkg => here)
func (t *token) encode(i Identifiable, clientID uint64, scope ...string) string {
roles := ""
for _, r := range i.Roles() {
for _, r := range identity.Roles() {
roles += fmt.Sprintf(" %d", r)
}
claims := jwt.MapClaims{
"sub": i.String(),
"exp": time.Now().Add(t.expiry).Unix(),
"aud": fmt.Sprintf("%d", clientID),
"scope": strings.Join(scope, " "),
"roles": strings.TrimSpace(roles),
// previous implementation had special a "salt" claim that ensured JWT uniquness
// we're using more standard approach with JWT ID now.
if err = token.Set(jwt.JwtIDKey, id.Next()); err != nil {
return
}
newToken := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
newToken.Header["salt"] = string(rand.Bytes(32))
access, _ := newToken.SignedString(t.secret)
return access
}
// HttpAuthenticator converts JWT claims into identity and stores it into context
func (t *token) HttpAuthenticator() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
tkn, claims, err := jwtauth.FromContext(ctx)
// When token is present, expect no errors and valid claims!
if tkn != nil {
if err != nil {
// But if token is present, the shouldn't be an error
api.Send(w, r, err)
return
}
ctx = SetIdentityToContext(ctx, ClaimsToIdentity(claims))
ctx = context.WithValue(ctx, scopeCtxKey{}, claims["scope"])
r = r.WithContext(ctx)
}
next.ServeHTTP(w, r)
})
if err = token.Set(jwt.SubjectKey, identity.String()); err != nil {
return
}
if err = token.Set(jwt.ExpirationKey, time.Now().Add(tm.expiry).Unix()); err != nil {
return
}
if err = token.Set(jwt.AudienceKey, fmt.Sprintf("%d", clientID)); err != nil {
return
}
if err = token.Set("scope", strings.Join(scope, " ")); err != nil {
return
}
if err = token.Set("roles", strings.TrimSpace(roles)); err != nil {
return
}
return jwt.Sign(token, tm.signAlgo, tm.signKey)
//claims := jwt.MapClaims{
// "sub": identity.String(),
// "exp": time.Now().Add(tm.expiry).Unix(),
// "aud": fmt.Sprintf("%d", clientID),
// "scope": strings.Join(scope, " "),
// "roles": strings.TrimSpace(roles),
//}
//
//newToken := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
//newToken.Header["salt"] = string(rand.Bytes(32))
//access, _ := newToken.SignedString(tm.secret)
//return access
}
func (t *token) Generate(ctx context.Context, i Identifiable) (tokenString string, err error) {
//// HttpVerifier returns a HTTP handler that verifies JWT and stores it into context
//func (t *tokenManager) HttpVerifier() func(http.Handler) http.Handler {
// //jwt.WithHTTPClient()
// return func(next http.Handler) http.Handler {
// return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// token, err := jwt.ParseRequest(req)
// if err != nil {
//
// }
//
// next.ServeHTTP(w, req)
// })
// }
//
// return jwtauth.Verifier(t.tokenAuth)
//}
//// HttpAuthenticator converts JWT claims into identity and stores it into context
//func (tm *tokenManager) HttpAuthenticator() func(http.Handler) http.Handler {
// return func(next http.Handler) http.Handler {
// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// ctx := r.Context()
//
// tkn, claims, err := jwtauth.FromContext(ctx)
//
// // When token is present, expect no errors and valid claims!
// if tkn != nil {
// if err != nil {
// // But if token is present, the shouldn't be an error
// api.Send(w, r, err)
// return
// }
//
// ctx = SetIdentityToContext(ctx, ClaimsToIdentity(claims))
// ctx = context.WithValue(ctx, scopeCtxKey{}, claims["scope"])
//
// r = r.WithContext(ctx)
// }
//
// next.ServeHTTP(w, r)
// })
// }
//}
// Generates JWT and stores alongside with client-confirmation entry,
func (tm *tokenManager) Generate(ctx context.Context, i Identifiable, clientID uint64, scope ...string) (token []byte, err error) {
var (
eti = GetExtraReqInfoFromContext(ctx)
oa2t = &types.AuthOa2token{
@@ -177,30 +247,31 @@ func (t *token) Generate(ctx context.Context, i Identifiable) (tokenString strin
CreatedAt: time.Now().Round(time.Second),
RemoteAddr: eti.RemoteAddr,
UserAgent: eti.UserAgent,
ClientID: clientID,
}
acc = &types.AuthConfirmedClient{
ConfirmedAt: oa2t.CreatedAt,
ClientID: clientID,
}
)
tokenString = t.Encode(i)
oa2t.Access = tokenString
oa2t.ExpiresAt = oa2t.CreatedAt.Add(t.expiry)
if token, err = tm.Encode(i, clientID, scope...); err != nil {
return
}
oa2t.Access = string(token)
// use the same expiration as on token
oa2t.ExpiresAt = oa2t.CreatedAt.Add(tm.expiry)
if oa2t.Data, err = json.Marshal(oa2t); err != nil {
return
}
// extend this with the client
oa2t.ClientID = 0
// copy client id to auth client confirmation
acc.ClientID = oa2t.ClientID
if oa2t.UserID, _ = ExtractFromSubClaim(i.String()); oa2t.UserID == 0 {
// UserID stores collection of IDs: user's ID and set of all roles' user is member of
return "", fmt.Errorf("could not parse user ID from token")
return nil, fmt.Errorf("could not parse user ID from token")
}
// copy user id to auth client confirmation
@@ -210,7 +281,7 @@ func (t *token) Generate(ctx context.Context, i Identifiable) (tokenString strin
return
}
return tokenString, DefaultJwtStore.CreateAuthOa2token(ctx, oa2t)
return token, DefaultJwtStore.CreateAuthOa2token(ctx, oa2t)
}
func GetExtraReqInfoFromContext(ctx context.Context) ExtraReqInfo {
@@ -223,40 +294,13 @@ func GetExtraReqInfoFromContext(ctx context.Context) ExtraReqInfo {
}
// ClaimsToIdentity decodes sub & roles claims into identity
func ClaimsToIdentity(c jwt.MapClaims) (i *identity) {
func IdentityFromToken(token jwt.Token) *identity {
var (
aux interface{}
ok bool
id, roles string
roles, _ = token.Get("roles")
)
if aux, ok = c["sub"]; !ok {
return
}
i = &identity{}
if id, ok = aux.(string); ok {
i.id, _ = strconv.ParseUint(id, 10, 64)
}
if i.id == 0 {
// pointless to decode roles if id is 0
return nil
}
if aux, ok = c["roles"]; !ok {
return
}
if roles, ok = aux.(string); !ok {
return
}
for _, role := range strings.Split(roles, " ") {
if roleID, _ := strconv.ParseUint(role, 10, 64); roleID != 0 {
i.memberOf = append(i.memberOf, roleID)
}
}
return Authenticated(i.id, i.memberOf...)
return Authenticated(
cast.ToUint64(token.Subject()),
payload.ParseUint64s(strings.Split(cast.ToString(roles), " "))...,
)
}

View File

@@ -16,6 +16,8 @@ func AccessTokenCheck(scope ...string) func(http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var ctx = r.Context()
jwtauth.Authenticator()
// retrieve token and claims from context
tkn, _, err := jwtauth.FromContext(ctx)
if err != nil || !tkn.Valid {

View File

@@ -662,7 +662,7 @@ func (svc service) exec(ctx context.Context, script string, runAs string, args S
zap.String("resourceType", args.ResourceType()),
)
token string
token []byte
)
// Returns context with identity set to service user
@@ -735,12 +735,12 @@ func (svc service) exec(ctx context.Context, script string, runAs string, args S
}
// Generate and save the token
token, err = svc.authTokenMaker.Generate(ctx, definer)
token, err = svc.authTokenMaker.Generate(ctx, definer, 0)
if err != nil {
return
}
if err = encodeArguments(req.Args, "authToken", token); err != nil {
if err = encodeArguments(req.Args, "authToken", string(token)); err != nil {
return
}
@@ -753,12 +753,12 @@ func (svc service) exec(ctx context.Context, script string, runAs string, args S
}
// Generate and save the token
token, err = svc.authTokenMaker.Generate(ctx, invoker)
token, err = svc.authTokenMaker.Generate(ctx, invoker, 0)
if err != nil {
return
}
if err = encodeArguments(req.Args, "authToken", token); err != nil {
if err = encodeArguments(req.Args, "authToken", string(token)); err != nil {
return
}
}

View File

@@ -5,15 +5,26 @@ import (
"net/http"
"sync"
"github.com/cortezaproject/corteza-server/pkg/auth"
"github.com/cortezaproject/corteza-server/pkg/errors"
"github.com/cortezaproject/corteza-server/pkg/options"
"github.com/cortezaproject/corteza-server/pkg/slice"
"github.com/dgrijalva/jwt-go"
"github.com/gorilla/websocket"
"go.uber.org/zap"
)
type (
server struct {
config options.WebsocketOpt
logger *zap.Logger
// user id => session id => session
sessions map[uint64]map[uint64]io.Writer
// keep lock on session map changes
l sync.RWMutex
}
)
var (
// upgrader handles websocket requests from peers
upgrader = websocket.Upgrader{
@@ -25,33 +36,15 @@ var (
}
)
type (
server struct {
config options.WebsocketOpt
logger *zap.Logger
// user id => session id => session
sessions map[uint64]map[uint64]io.Writer
accessToken interface {
Authenticate(string) (jwt.MapClaims, error)
}
// keep lock on session map changes
l sync.RWMutex
}
)
func Server(logger *zap.Logger, config options.WebsocketOpt) *server {
if !config.LogEnabled {
logger = zap.NewNop()
}
return &server{
config: config,
logger: logger.Named("websocket"),
accessToken: auth.DefaultJwtHandler,
sessions: make(map[uint64]map[uint64]io.Writer),
config: config,
logger: logger.Named("websocket"),
sessions: make(map[uint64]map[uint64]io.Writer),
}
}

View File

@@ -4,15 +4,17 @@ import (
"context"
"encoding/json"
"fmt"
"net"
"sync"
"time"
"github.com/cortezaproject/corteza-server/pkg/auth"
"github.com/cortezaproject/corteza-server/pkg/errors"
"github.com/cortezaproject/corteza-server/pkg/id"
"github.com/cortezaproject/corteza-server/pkg/options"
"github.com/gorilla/websocket"
"github.com/lestrrat-go/jwx/jwt"
"go.uber.org/zap"
"net"
"sync"
"time"
)
// active sessions of users
@@ -285,17 +287,21 @@ func (s *session) writeLoop() error {
}
func (s *session) authenticate(p *payloadAuth) error {
claims, err := s.server.accessToken.Authenticate(p.AccessToken)
token, err := jwt.Parse([]byte(p.AccessToken))
if err != nil {
return err
}
if !auth.CheckScope(claims["scope"], "api") {
if err = jwt.Validate(token); err != nil {
return err
}
if scope, has := token.Get("scope"); !has || !auth.CheckScope(scope, "api") {
return fmt.Errorf("client does not allow use of websockets (missing 'api' scope)")
}
// Get identity using JWT claims
identity := auth.ClaimsToIdentity(claims)
identity := auth.IdentityFromToken(token)
if s.identity != nil {
if s.identity.Identity() != identity.Identity() {
@@ -308,7 +314,7 @@ func (s *session) authenticate(p *payloadAuth) error {
}
s.identity = identity
s.Write([]byte(ok))
_, _ = s.Write(ok)
return nil
}

View File

@@ -2,6 +2,7 @@ package rest
import (
"context"
"github.com/cortezaproject/corteza-server/pkg/auth"
"github.com/cortezaproject/corteza-server/pkg/payload"
"github.com/cortezaproject/corteza-server/system/rest/request"
@@ -70,13 +71,13 @@ func (ctrl *Auth) makePayload(ctx context.Context, user *types.User) (*authUserR
}
// Generate and save the token
t, err := ctrl.tokenHandler.Generate(ctx, user)
t, err := ctrl.tokenHandler.Generate(ctx, user, 0)
if err != nil {
return nil, nil
}
return &authUserResponse{
JWT: t,
JWT: string(t),
User: &authUserPayload{
userPayload: &userPayload{
ID: user.ID,

View File

@@ -43,7 +43,7 @@ type (
cUser *sysTypes.User
roleID uint64
token string
token []byte
}
)
@@ -111,7 +111,7 @@ func newHelper(t *testing.T) helper {
helpers.UpdateRBAC(h.roleID)
var err error
h.token, err = auth.DefaultJwtHandler.Generate(context.Background(), h.cUser)
h.token, err = auth.DefaultJwtHandler.Generate(context.Background(), h.cUser, 0)
if err != nil {
panic(err)
}

View File

@@ -37,7 +37,7 @@ type (
cUser *sysTypes.User
roleID uint64
token string
token []byte
}
)
@@ -107,7 +107,7 @@ func newHelper(t *testing.T) helper {
helpers.UpdateRBAC(h.roleID)
var err error
h.token, err = auth.DefaultJwtHandler.Generate(context.Background(), h.cUser)
h.token, err = auth.DefaultJwtHandler.Generate(context.Background(), h.cUser, 0)
if err != nil {
panic(err)
}

View File

@@ -87,7 +87,7 @@ func newHelper(t *testing.T) helper {
helpers.UpdateRBAC(h.roleID)
var err error
h.token, err = auth.DefaultJwtHandler.Generate(context.Background(), h.cUser)
h.token, err = auth.DefaultJwtHandler.Generate(context.Background(), h.cUser, 0)
if err != nil {
panic(err)
}

View File

@@ -171,7 +171,7 @@ func TestSuccessfulNodePairing(t *testing.T) {
//Debug().
// make sure we do not use test auth-token for authentication but
// we do it with pairing token
Intercept(helpers.ReqHeaderRawAuthBearer(n.AuthToken)).
Intercept(helpers.ReqHeaderRawAuthBearer([]byte(n.AuthToken))).
Post(fmt.Sprintf("/nodes/%d/handshake", n.SharedNodeID)).
FormData("pairToken", n.PairToken).
FormData("authToken", authToken).
@@ -207,7 +207,7 @@ func TestSuccessfulNodePairing(t *testing.T) {
//Debug().
// make sure we do not use test auth-token but
// one provided to us in the initial handshake step
Intercept(helpers.ReqHeaderRawAuthBearer(n.AuthToken)).
Intercept(helpers.ReqHeaderRawAuthBearer([]byte(n.AuthToken))).
Post(fmt.Sprintf("/nodes/%d/handshake-complete", n.SharedNodeID)).
FormData("authToken", authToken).
Expect(t).
@@ -221,7 +221,7 @@ func TestSuccessfulNodePairing(t *testing.T) {
h.apiInit().
//Debug().
Intercept(helpers.ReqHeaderRawAuthBearer(getNodeAuthToken(aNodeID))).
Intercept(helpers.ReqHeaderRawAuthBearer([]byte(getNodeAuthToken(aNodeID)))).
Post(fmt.Sprintf("/nodes/%d/handshake-confirm", aNodeID)).
Expect(t).
Status(http.StatusOK).

View File

@@ -16,8 +16,8 @@ func BindAuthMiddleware(r chi.Router) {
)
}
func ReqHeaderRawAuthBearer(token string) apitest.Intercept {
func ReqHeaderRawAuthBearer(token []byte) apitest.Intercept {
return func(req *http.Request) {
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Authorization", "Bearer "+string(token))
}
}

View File

@@ -125,7 +125,7 @@ func newHelper(t *testing.T) helper {
helpers.UpdateRBAC(h.roleID)
var err error
h.token, err = auth.DefaultJwtHandler.Generate(context.Background(), h.cUser)
h.token, err = auth.DefaultJwtHandler.Generate(context.Background(), h.cUser, 0)
if err != nil {
panic(err)
}

View File

@@ -16,7 +16,6 @@ import (
"github.com/cortezaproject/corteza-server/system/types"
s "github.com/crewjam/saml"
"github.com/crewjam/saml/samlsp"
"github.com/dgrijalva/jwt-go"
"github.com/steinfletcher/apitest"
"go.uber.org/zap"
)
@@ -92,10 +91,10 @@ func TestAuthExternalSAMLSuccess(t *testing.T) {
s.MaxClockSkew = time.Hour
s.MaxIssueDelay = time.Hour
jwt.TimeFunc = func() time.Time {
tm, _ := time.Parse("2006-01-2 15:04:05", "2021-05-17 09:17:10")
return tm
}
//jwt.TimeFunc = func() time.Time {
// tm, _ := time.Parse("2006-01-2 15:04:05", "2021-05-17 09:17:10")
// return tm
//}
s.TimeNow = func() time.Time {
tm, _ := time.Parse("2006-01-2 15:04:05", "2021-05-17 09:17:10")

View File

@@ -142,7 +142,7 @@ func newHelper(t *testing.T) helper {
h.mockPermissionsWithAccess()
var err error
h.token, err = auth.DefaultJwtHandler.Generate(context.Background(), h.cUser)
h.token, err = auth.DefaultJwtHandler.Generate(context.Background(), h.cUser, 0)
if err != nil {
panic(err)
}