diff --git a/auth/db/schema/mysql/20180704080000.base.up.sql b/auth/db/schema/mysql/20180704080000.base.up.sql index c22172832..69b06eb75 100644 --- a/auth/db/schema/mysql/20180704080000.base.up.sql +++ b/auth/db/schema/mysql/20180704080000.base.up.sql @@ -15,6 +15,7 @@ CREATE TABLE users ( name TEXT NOT NULL, handle TEXT NOT NULL, meta JSON NOT NULL, + satosa_id CHAR(36) NULL, rel_organisation BIGINT UNSIGNED NOT NULL, @@ -25,3 +26,5 @@ CREATE TABLE users ( PRIMARY KEY (id) ) ENGINE=InnoDB DEFAULT CHARSET=utf8; + +CREATE UNIQUE INDEX uid_satosa ON users (satosa_id); diff --git a/auth/repository/user.go b/auth/repository/user.go index 247363eea..f3c609dc9 100644 --- a/auth/repository/user.go +++ b/auth/repository/user.go @@ -20,6 +20,7 @@ type ( FindUserByEmail(email string) (*types.User, error) FindUserByUsername(username string) (*types.User, error) FindUserByID(id uint64) (*types.User, error) + FindUserBySatosaID(id string) (*types.User, error) FindUsers(filter *types.UserFilter) ([]*types.User, error) CreateUser(mod *types.User) (*types.User, error) UpdateUser(mod *types.User) (*types.User, error) @@ -30,8 +31,11 @@ type ( ) const ( + sqlUserColumns = "id, email, username, password, name, handle, " + + "meta, satosa_id, rel_organisation, " + + "created_at, updated_at, suspended_at, deleted_at" sqlUserScope = "suspended_at IS NULL AND deleted_at IS NULL" - sqlUserSelect = "SELECT * FROM users WHERE " + sqlUserScope + sqlUserSelect = "SELECT " + sqlUserColumns + " FROM users WHERE " + sqlUserScope ErrUserNotFound = repositoryError("UserNotFound") ) @@ -47,21 +51,28 @@ func (r *user) With(ctx context.Context) User { } func (r *user) FindUserByUsername(username string) (*types.User, error) { - sql := "SELECT * FROM users WHERE username = ? AND " + sqlUserScope + sql := sqlUserSelect + " AND username = ?" mod := &types.User{} return mod, isFound(r.db().Get(mod, sql, username), mod.ID > 0, ErrUserNotFound) } +func (r *user) FindUserBySatosaID(satosaID string) (*types.User, error) { + sql := sqlUserSelect + " AND satosa_id = ?" + mod := &types.User{} + + return mod, isFound(r.db().Get(mod, sql, satosaID), mod.ID > 0, ErrUserNotFound) +} + func (r *user) FindUserByEmail(email string) (*types.User, error) { - sql := "SELECT * FROM users WHERE email = ? AND " + sqlUserScope + sql := sqlUserSelect + " AND email = ?" mod := &types.User{} return mod, isFound(r.db().Get(mod, sql, email), mod.ID > 0, ErrUserNotFound) } func (r *user) FindUserByID(id uint64) (*types.User, error) { - sql := "SELECT * FROM users WHERE id = ? AND " + sqlUserScope + sql := sqlUserSelect + " AND id = ?" mod := &types.User{} return mod, isFound(r.db().Get(mod, sql, id), mod.ID > 0, ErrUserNotFound) @@ -70,7 +81,7 @@ func (r *user) FindUserByID(id uint64) (*types.User, error) { func (r *user) FindUsers(filter *types.UserFilter) ([]*types.User, error) { rval := make([]*types.User, 0) params := make([]interface{}, 0) - sql := "SELECT * FROM users WHERE " + sqlUserScope + sql := sqlUserSelect if filter != nil { if filter.Query != "" { diff --git a/auth/rest/oidc.go b/auth/rest/oidc.go index 27d6cb2f5..a299db18b 100644 --- a/auth/rest/oidc.go +++ b/auth/rest/oidc.go @@ -34,6 +34,12 @@ type ( jwt jwtEncodeCookieSetter } + oidcProfile struct { + Email string `json:"email"` + Name string `json:"name"` + Sub string `json:"sub"` + } + jwtEncodeCookieSetter interface { auth.TokenEncoder SetCookie(w http.ResponseWriter, r *http.Request, identity auth.Identifiable) @@ -153,12 +159,15 @@ func (c *openIdConnect) HandleOAuth2Callback(w http.ResponseWriter, r *http.Requ } u, _ := c.provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token)) + p := &oidcProfile{} + u.Claims(p) var user = &types.User{ - Email: u.Email, + Email: p.Email, + Name: p.Name, } - if user, err = c.userService.FindOrCreate(user); err != nil { + if user, err = c.userService.With(ctx).FindOrCreate(user); err != nil { resputil.JSON(w, err) return } else { diff --git a/auth/service/user.go b/auth/service/user.go index 814bdb0fd..85f081004 100644 --- a/auth/service/user.go +++ b/auth/service/user.go @@ -68,15 +68,27 @@ func (svc *user) Find(filter *types.UserFilter) ([]*types.User, error) { return svc.repository.FindUsers(filter) } -// Finds if user with a specific email exists and returns it otherwise it creates a fresh one +// Finds if user with a specific satosa id exists and returns it otherwise it creates a fresh one func (svc *user) FindOrCreate(user *types.User) (out *types.User, err error) { //return out, svc.repository.DB().Transaction(func() error { - if out, err = svc.repository.FindUserByEmail(user.Email); err != repository.ErrUserNotFound { - return out, err - } else if out, err = svc.repository.CreateUser(user); err != nil { + out, err = svc.repository.FindUserBySatosaID(user.SatosaID) + + if err == repository.ErrUserNotFound { + out, err = svc.repository.CreateUser(user) return out, err } + if err != nil { + // FindUserBySatosaID error + return nil, err + } + + // @todo need to be more selective with fields we update... + out, err = svc.repository.UpdateUser(out) + if err != nil { + return nil, err + } + return out, nil //}) } diff --git a/auth/types/user.go b/auth/types/user.go index 38c509351..b4c690ea4 100644 --- a/auth/types/user.go +++ b/auth/types/user.go @@ -2,8 +2,9 @@ package types import ( "encoding/json" - "golang.org/x/crypto/bcrypt" "time" + + "golang.org/x/crypto/bcrypt" ) type ( @@ -13,6 +14,7 @@ type ( Email string `json:"email" db:"email"` Name string `json:"name" db:"name"` Handle string `json:"handle" db:"handle"` + SatosaID string `json:"satosaId" db:"satosa_id"` Meta json.RawMessage `json:"-" db:"meta"` OrganisationID uint64 `json:"organisationId" db:"rel_organisation"` Password []byte `json:"-" db:"password"`