diff --git a/system/internal/service/auth.go b/system/internal/service/auth.go index 3eddf4dda..554a33a23 100644 --- a/system/internal/service/auth.go +++ b/system/internal/service/auth.go @@ -31,7 +31,8 @@ type ( External(profile goth.User) (*types.User, error) - CheckPassword(email string, password []byte) (*types.User, error) + LocalSignUp(*types.User) (*types.User, error) + LocalSignIn(email string, password []byte) (*types.User, error) ChangePassword(user *types.User, password []byte) error CheckCredentials(credentialsID uint64, secret string) (*types.User, error) RevokeCredentialsByID(user *types.User, credentialsID uint64) error @@ -203,11 +204,64 @@ func (svc *auth) External(profile goth.User) (u *types.User, err error) { }) } -// CheckPassword verifies username/password combination +// LocalSignUp protocol +// +// Forgiving but strict: valid existing users get notified +func (svc auth) LocalSignUp(new *types.User) (u *types.User, err error) { + // @todo do settings allow us to create new users? + + if new == nil { + return nil, errors.New("invalid signup input") + } + + if err = svc.validateLocalSignUp(new.Email); err != nil { + return + } + + return u, svc.db.Transaction(func() error { + existing, err := svc.users.FindByEmail(new.Email) + + if err == nil && existing.Valid() { + // User already exists, but we're nice and we'll send this user an + // email that will help him to login + // + // @todo check if user is suspended + _ = existing.SuspendedAt + // @todo check if user is deleted + _ = existing.DeletedAt + // @todo otherwise send this user a nice email (someone, probably you....) + return nil + } + + if err != repository.ErrUserNotFound { + return errors.Wrap(err, "could not check existing emails") + } + + // Whitelisted user data to copy + u, err = svc.users.Create(&types.User{ + Email: new.Email, + Name: new.Name, + Username: new.Username, + Handle: new.Handle, + }) + + return errors.Wrap(err, "could not create new user") + }) +} + +func (svc auth) validateLocalSignUp(email string) (err error) { + if !reEmail.MatchString(email) { + return errors.New("invalid email format") + } + + return nil +} + +// LocalSignIn verifies username/password combination in the local credentials table // // Expects plain text password as an input -func (svc *auth) CheckPassword(email string, password []byte) (u *types.User, err error) { - if err = svc.validateCredentials(email, password); err != nil { +func (svc *auth) LocalSignIn(email string, password []byte) (u *types.User, err error) { + if err = svc.validateLocalSignIn(email, password); err != nil { return } @@ -217,7 +271,7 @@ func (svc *auth) CheckPassword(email string, password []byte) (u *types.User, er ) u, err = svc.users.FindByEmail(email) - if err != repository.ErrUserNotFound { + if err == repository.ErrUserNotFound { return errors.New("invalid username/password combination") } @@ -234,8 +288,8 @@ func (svc *auth) CheckPassword(email string, password []byte) (u *types.User, er }) } -// validateCredentials does basic format & length check -func (svc auth) validateCredentials(email string, password []byte) error { +// validateLocalSignIn does basic format & length check +func (svc auth) validateLocalSignIn(email string, password []byte) error { if !reEmail.MatchString(email) { return errors.New("invalid email format") } diff --git a/system/internal/service/auth_test.go b/system/internal/service/auth_test.go index c9eab757e..8eea9f3d4 100644 --- a/system/internal/service/auth_test.go +++ b/system/internal/service/auth_test.go @@ -1,5 +1,3 @@ -// +-build unit - package service import ( @@ -112,7 +110,7 @@ func TestAuth_External_NonExisting(t *testing.T) { } } -func Test_auth_validateCredentials(t *testing.T) { +func Test_auth_validateLocalSignIn(t *testing.T) { type args struct { email string password []byte @@ -130,8 +128,8 @@ func Test_auth_validateCredentials(t *testing.T) { svc := auth{} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := svc.validateCredentials(tt.args.email, tt.args.password); (err != nil) != tt.wantErr { - t.Errorf("auth.validateCredentials() error = %v, wantErr %v", err, tt.wantErr) + if err := svc.validateLocalSignIn(tt.args.email, tt.args.password); (err != nil) != tt.wantErr { + t.Errorf("auth.validateLocalSignIn() error = %v, wantErr %v", err, tt.wantErr) } }) }