Credentials & social-sign-on, fixed mocks
This commit is contained in:
parent
7c36966e68
commit
ae733dc2c9
6
Makefile
6
Makefile
@ -140,7 +140,7 @@ qa: vet critic test
|
||||
|
||||
mocks: $(GOMOCK)
|
||||
# Cleanup all pre-generated
|
||||
rm -f */*/*_mock_test.go
|
||||
rm -f */*/*_mock_test.go */*/mocks/*
|
||||
|
||||
# See https://github.com/golang/mock for details
|
||||
$(MOCKGEN) -package service -source crm/service/notification.go -destination crm/service/notification_mock_test.go
|
||||
@ -154,6 +154,10 @@ mocks: $(GOMOCK)
|
||||
|
||||
$(MOCKGEN) -package mail -source internal/mail/mail.go -destination internal/mail/mail_mock_test.go
|
||||
|
||||
mkdir -p system/repository/mocks
|
||||
$(MOCKGEN) -package repository -source system/repository/user.go -destination system/repository/mocks/user.go
|
||||
$(MOCKGEN) -package repository -source system/repository/credentials.go -destination system/repository/mocks/credentials.go
|
||||
|
||||
|
||||
########################################################################################################################
|
||||
# Toolset
|
||||
|
||||
@ -49,6 +49,7 @@ function types {
|
||||
./build/gen-type-set --no-pk-types Unread --output sam/types/unread.gen.go
|
||||
|
||||
./build/gen-type-set --types User --output system/types/user.gen.go
|
||||
./build/gen-type-set --types Credentials --output system/types/credentials.gen.go
|
||||
green "OK"
|
||||
}
|
||||
|
||||
|
||||
71
internal/config/social.go
Normal file
71
internal/config/social.go
Normal file
@ -0,0 +1,71 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/namsral/flag"
|
||||
)
|
||||
|
||||
type (
|
||||
Social struct {
|
||||
Enabled bool
|
||||
|
||||
FacebookKey string
|
||||
FacebookSecret string
|
||||
GPlusKey string
|
||||
GPlusSecret string
|
||||
GitHubKey string
|
||||
GitHubSecret string
|
||||
LinkedInKey string
|
||||
LinkedInSecret string
|
||||
|
||||
Url string
|
||||
|
||||
SessionStoreSecret string
|
||||
SessionStoreExpiry int // seconds!
|
||||
}
|
||||
)
|
||||
|
||||
var social *Social
|
||||
|
||||
func (c *Social) Validate() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.Enabled == false {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.SessionStoreSecret == "" {
|
||||
return errors.New("Session store secret not set for SOCIAL")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*Social) Init(prefix ...string) *Social {
|
||||
if social != nil {
|
||||
return social
|
||||
}
|
||||
|
||||
b := func(name string, k, s *string) {
|
||||
flag.StringVar(k, "auth-social-"+strings.ToLower(name)+"-key", "", name+" key")
|
||||
flag.StringVar(s, "auth-social-"+strings.ToLower(name)+"-secret", "", name+" secret")
|
||||
|
||||
}
|
||||
|
||||
social = new(Social)
|
||||
flag.BoolVar(&social.Enabled, "auth-social-enabled", true, "SocialAuth enabled")
|
||||
|
||||
b("Facebook", &social.FacebookKey, &social.FacebookSecret)
|
||||
b("GPlus", &social.GPlusKey, &social.GPlusSecret)
|
||||
b("GitHub", &social.GitHubKey, &social.GitHubSecret)
|
||||
b("LinkedIn", &social.LinkedInKey, &social.LinkedInSecret)
|
||||
|
||||
flag.StringVar(&social.Url, "auth-social-url", "", "Base URL")
|
||||
flag.StringVar(&social.SessionStoreSecret, "auth-social-session-store-secret", "", "Session store secret")
|
||||
flag.IntVar(&social.SessionStoreExpiry, "auth-social-state-cookie-expiry", 60*15, "SocialAuth State cookie expiry in seconds")
|
||||
return social
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
19
system/db/schema/mysql/20181208140000.credentials.up.sql
Normal file
19
system/db/schema/mysql/20181208140000.credentials.up.sql
Normal file
@ -0,0 +1,19 @@
|
||||
-- Keeps all known users, home and external organisation
|
||||
-- changes are stored in audit log
|
||||
CREATE TABLE sys_credentials (
|
||||
id BIGINT UNSIGNED NOT NULL,
|
||||
rel_owner BIGINT UNSIGNED NOT NULL REFERENCES sys_users(id),
|
||||
label TEXT NOT NULL COMMENT 'something we can differentiate credentials by',
|
||||
kind VARCHAR(128) NOT NULL COMMENT 'hash, facebook, gplus, github, linkedin ...',
|
||||
credentials TEXT NOT NULL COMMENT 'crypted/hashed passwords, secrets, social profile ID',
|
||||
meta JSON NOT NULL,
|
||||
expires_at DATETIME NULL,
|
||||
|
||||
created_at DATETIME NOT NULL DEFAULT NOW(),
|
||||
updated_at DATETIME NULL,
|
||||
deleted_at DATETIME NULL, -- user soft delete
|
||||
|
||||
PRIMARY KEY (id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
|
||||
|
||||
CREATE INDEX idx_owner ON sys_credentials (rel_owner);
|
||||
@ -13,7 +13,7 @@ type (
|
||||
monitor *config.Monitor
|
||||
db *config.Database
|
||||
oidc *config.OIDC
|
||||
//jwt *config.JWT
|
||||
social *config.Social
|
||||
}
|
||||
)
|
||||
|
||||
@ -57,6 +57,6 @@ func Flags(prefix ...string) {
|
||||
new(config.Monitor).Init(prefix...),
|
||||
new(config.Database).Init(prefix...),
|
||||
new(config.OIDC).Init(prefix...),
|
||||
//new(config.JWT).Init(prefix...),
|
||||
new(config.Social).Init(prefix...),
|
||||
}
|
||||
}
|
||||
|
||||
100
system/repository/credentials.go
Normal file
100
system/repository/credentials.go
Normal file
@ -0,0 +1,100 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
type (
|
||||
CredentialsRepository interface {
|
||||
With(ctx context.Context, db *factory.DB) CredentialsRepository
|
||||
|
||||
FindByID(ID uint64) (*types.Credentials, error)
|
||||
FindByCredentials(kind types.CredentialsKind, credentials string) (cc types.CredentialsSet, err error)
|
||||
FindByKind(ownerID uint64, kind types.CredentialsKind) (cc types.CredentialsSet, err error)
|
||||
FindByOwnerID(ownerID uint64) (cc types.CredentialsSet, err error)
|
||||
Find() (cc types.CredentialsSet, err error)
|
||||
|
||||
Create(c *types.Credentials) (*types.Credentials, error)
|
||||
DeleteByID(id uint64) error
|
||||
}
|
||||
|
||||
credentials struct {
|
||||
*repository
|
||||
|
||||
// sql table reference
|
||||
tblname string
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
sqlCredentialsColumns = "id, rel_owner, kind, label, credentials, meta, expires_at, " +
|
||||
"created_at, deleted_at"
|
||||
sqlCredentialsScope = "deleted_at IS NULL"
|
||||
sqlCredentialsSelect = "SELECT " + sqlCredentialsColumns + " FROM %s WHERE " + sqlCredentialsScope
|
||||
|
||||
ErrCredentialsNotFound = repositoryError("CredentialsNotFound")
|
||||
)
|
||||
|
||||
func Credentials(ctx context.Context, db *factory.DB) CredentialsRepository {
|
||||
return (&credentials{}).With(ctx, db)
|
||||
}
|
||||
|
||||
func (r *credentials) With(ctx context.Context, db *factory.DB) CredentialsRepository {
|
||||
return &credentials{
|
||||
repository: r.repository.With(ctx, db),
|
||||
tblname: "sys_credentials",
|
||||
}
|
||||
}
|
||||
|
||||
func (r *credentials) FindByID(ID uint64) (*types.Credentials, error) {
|
||||
sql := fmt.Sprintf(sqlCredentialsSelect, r.tblname) + " AND id = ?"
|
||||
mod := &types.Credentials{}
|
||||
|
||||
return mod, isFound(r.db().Get(mod, sql, ID), mod.ID > 0, ErrCredentialsNotFound)
|
||||
}
|
||||
|
||||
func (r *credentials) FindByCredentials(kind types.CredentialsKind, credentials string) (cc types.CredentialsSet, err error) {
|
||||
return r.fetchSet(
|
||||
fmt.Sprintf(sqlCredentialsSelect+" AND kind = ? AND credentials = ?", r.tblname),
|
||||
kind,
|
||||
credentials)
|
||||
}
|
||||
|
||||
func (r *credentials) FindByKind(ownerID uint64, kind types.CredentialsKind) (cc types.CredentialsSet, err error) {
|
||||
return r.fetchSet(
|
||||
fmt.Sprintf(sqlCredentialsSelect+" AND rel_owner = ? AND kind = ?", r.tblname),
|
||||
ownerID,
|
||||
kind)
|
||||
}
|
||||
|
||||
func (r *credentials) FindByOwnerID(ownerID uint64) (cc types.CredentialsSet, err error) {
|
||||
return r.fetchSet(
|
||||
fmt.Sprintf(sqlCredentialsSelect+" AND rel_owner = ?", r.tblname),
|
||||
ownerID)
|
||||
}
|
||||
|
||||
func (r *credentials) Find() (cc types.CredentialsSet, err error) {
|
||||
return r.fetchSet(
|
||||
fmt.Sprintf(sqlCredentialsSelect, r.tblname))
|
||||
}
|
||||
|
||||
func (r *credentials) fetchSet(sql string, args ...interface{}) (cc types.CredentialsSet, err error) {
|
||||
cc = types.CredentialsSet{}
|
||||
return cc, r.db().Select(&cc, sql, args...)
|
||||
}
|
||||
|
||||
func (r *credentials) Create(c *types.Credentials) (*types.Credentials, error) {
|
||||
c.ID = factory.Sonyflake.NextID()
|
||||
c.CreatedAt = time.Now()
|
||||
return c, r.db().Insert(r.tblname, c)
|
||||
}
|
||||
|
||||
func (r *credentials) DeleteByID(id uint64) error {
|
||||
return r.updateColumnByID(r.tblname, "deleted_at", time.Now(), id)
|
||||
}
|
||||
58
system/repository/credentials_test.go
Normal file
58
system/repository/credentials_test.go
Normal file
@ -0,0 +1,58 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
func TestCredentials(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode.")
|
||||
return
|
||||
}
|
||||
|
||||
crepo := Credentials(context.Background(), factory.Database.MustGet())
|
||||
|
||||
{
|
||||
cc := types.CredentialsSet{
|
||||
&types.Credentials{OwnerID: 10000, Kind: types.CredentialsKindLinkedin, Credentials: "linkedin-profile-id"},
|
||||
&types.Credentials{OwnerID: 10000, Kind: types.CredentialsKindGPlus, Credentials: "gplus-profile-id"},
|
||||
&types.Credentials{OwnerID: 20000, Kind: types.CredentialsKindFacebook, Credentials: "facebook-profile-id"},
|
||||
}
|
||||
|
||||
tx(t, func() (err error) {
|
||||
if _, err = factory.Database.MustGet().Exec("TRUNCATE sys_credentials"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, c := range cc {
|
||||
cNew, err := crepo.Create(c)
|
||||
assert(t, err == nil, "Credentials.Create error: %+v", err)
|
||||
assert(t, c.ID > 0, "Expecting credentials to have a valid ID")
|
||||
assert(t, c.Valid(), "Expecting credentials to be valid after creation")
|
||||
|
||||
_, err = crepo.FindByID(cNew.ID)
|
||||
assert(t, err == nil, "Credentials.FindByID error: %+v", err)
|
||||
|
||||
{
|
||||
r, err := crepo.FindByKind(c.OwnerID, c.Kind)
|
||||
assert(t, err == nil, "Credentials.FindByKind error: %+v", err)
|
||||
assert(t, len(r) == 1, "Expecting exactly 1 result from FindByKind, got: %v", len(r))
|
||||
}
|
||||
|
||||
{
|
||||
r, err := crepo.FindByCredentials(c.Kind, c.Credentials)
|
||||
assert(t, err == nil, "Credentials.FindByKind error: %+v", err)
|
||||
assert(t, len(r) == 1, "Expecting exactly 1 result from FindByCredentials, got: %v", len(r))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
139
system/repository/mocks/credentials.go
Normal file
139
system/repository/mocks/credentials.go
Normal file
@ -0,0 +1,139 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: system/repository/credentials.go
|
||||
|
||||
// Package repository is a generated GoMock package.
|
||||
package repository
|
||||
|
||||
import (
|
||||
context "context"
|
||||
repository "github.com/crusttech/crust/system/repository"
|
||||
types "github.com/crusttech/crust/system/types"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
factory "github.com/titpetric/factory"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockCredentialsRepository is a mock of CredentialsRepository interface
|
||||
type MockCredentialsRepository struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockCredentialsRepositoryMockRecorder
|
||||
}
|
||||
|
||||
// MockCredentialsRepositoryMockRecorder is the mock recorder for MockCredentialsRepository
|
||||
type MockCredentialsRepositoryMockRecorder struct {
|
||||
mock *MockCredentialsRepository
|
||||
}
|
||||
|
||||
// NewMockCredentialsRepository creates a new mock instance
|
||||
func NewMockCredentialsRepository(ctrl *gomock.Controller) *MockCredentialsRepository {
|
||||
mock := &MockCredentialsRepository{ctrl: ctrl}
|
||||
mock.recorder = &MockCredentialsRepositoryMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockCredentialsRepository) EXPECT() *MockCredentialsRepositoryMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// With mocks base method
|
||||
func (m *MockCredentialsRepository) With(ctx context.Context, db *factory.DB) repository.CredentialsRepository {
|
||||
ret := m.ctrl.Call(m, "With", ctx, db)
|
||||
ret0, _ := ret[0].(repository.CredentialsRepository)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// With indicates an expected call of With
|
||||
func (mr *MockCredentialsRepositoryMockRecorder) With(ctx, db interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "With", reflect.TypeOf((*MockCredentialsRepository)(nil).With), ctx, db)
|
||||
}
|
||||
|
||||
// FindByID mocks base method
|
||||
func (m *MockCredentialsRepository) FindByID(ID uint64) (*types.Credentials, error) {
|
||||
ret := m.ctrl.Call(m, "FindByID", ID)
|
||||
ret0, _ := ret[0].(*types.Credentials)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByID indicates an expected call of FindByID
|
||||
func (mr *MockCredentialsRepositoryMockRecorder) FindByID(ID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByID", reflect.TypeOf((*MockCredentialsRepository)(nil).FindByID), ID)
|
||||
}
|
||||
|
||||
// FindByCredentials mocks base method
|
||||
func (m *MockCredentialsRepository) FindByCredentials(kind types.CredentialsKind, credentials string) (types.CredentialsSet, error) {
|
||||
ret := m.ctrl.Call(m, "FindByCredentials", kind, credentials)
|
||||
ret0, _ := ret[0].(types.CredentialsSet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByCredentials indicates an expected call of FindByCredentials
|
||||
func (mr *MockCredentialsRepositoryMockRecorder) FindByCredentials(kind, credentials interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByCredentials", reflect.TypeOf((*MockCredentialsRepository)(nil).FindByCredentials), kind, credentials)
|
||||
}
|
||||
|
||||
// FindByKind mocks base method
|
||||
func (m *MockCredentialsRepository) FindByKind(ownerID uint64, kind types.CredentialsKind) (types.CredentialsSet, error) {
|
||||
ret := m.ctrl.Call(m, "FindByKind", ownerID, kind)
|
||||
ret0, _ := ret[0].(types.CredentialsSet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByKind indicates an expected call of FindByKind
|
||||
func (mr *MockCredentialsRepositoryMockRecorder) FindByKind(ownerID, kind interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByKind", reflect.TypeOf((*MockCredentialsRepository)(nil).FindByKind), ownerID, kind)
|
||||
}
|
||||
|
||||
// FindByOwnerID mocks base method
|
||||
func (m *MockCredentialsRepository) FindByOwnerID(ownerID uint64) (types.CredentialsSet, error) {
|
||||
ret := m.ctrl.Call(m, "FindByOwnerID", ownerID)
|
||||
ret0, _ := ret[0].(types.CredentialsSet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByOwnerID indicates an expected call of FindByOwnerID
|
||||
func (mr *MockCredentialsRepositoryMockRecorder) FindByOwnerID(ownerID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByOwnerID", reflect.TypeOf((*MockCredentialsRepository)(nil).FindByOwnerID), ownerID)
|
||||
}
|
||||
|
||||
// Find mocks base method
|
||||
func (m *MockCredentialsRepository) Find() (types.CredentialsSet, error) {
|
||||
ret := m.ctrl.Call(m, "Find")
|
||||
ret0, _ := ret[0].(types.CredentialsSet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Find indicates an expected call of Find
|
||||
func (mr *MockCredentialsRepositoryMockRecorder) Find() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Find", reflect.TypeOf((*MockCredentialsRepository)(nil).Find))
|
||||
}
|
||||
|
||||
// Create mocks base method
|
||||
func (m *MockCredentialsRepository) Create(c *types.Credentials) (*types.Credentials, error) {
|
||||
ret := m.ctrl.Call(m, "Create", c)
|
||||
ret0, _ := ret[0].(*types.Credentials)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Create indicates an expected call of Create
|
||||
func (mr *MockCredentialsRepositoryMockRecorder) Create(c interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockCredentialsRepository)(nil).Create), c)
|
||||
}
|
||||
|
||||
// DeleteByID mocks base method
|
||||
func (m *MockCredentialsRepository) DeleteByID(id uint64) error {
|
||||
ret := m.ctrl.Call(m, "DeleteByID", id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteByID indicates an expected call of DeleteByID
|
||||
func (mr *MockCredentialsRepositoryMockRecorder) DeleteByID(id interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteByID", reflect.TypeOf((*MockCredentialsRepository)(nil).DeleteByID), id)
|
||||
}
|
||||
176
system/repository/mocks/user.go
Normal file
176
system/repository/mocks/user.go
Normal file
@ -0,0 +1,176 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: system/repository/user.go
|
||||
|
||||
// Package repository is a generated GoMock package.
|
||||
package repository
|
||||
|
||||
import (
|
||||
context "context"
|
||||
repository "github.com/crusttech/crust/system/repository"
|
||||
types "github.com/crusttech/crust/system/types"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
factory "github.com/titpetric/factory"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockUserRepository is a mock of UserRepository interface
|
||||
type MockUserRepository struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockUserRepositoryMockRecorder
|
||||
}
|
||||
|
||||
// MockUserRepositoryMockRecorder is the mock recorder for MockUserRepository
|
||||
type MockUserRepositoryMockRecorder struct {
|
||||
mock *MockUserRepository
|
||||
}
|
||||
|
||||
// NewMockUserRepository creates a new mock instance
|
||||
func NewMockUserRepository(ctrl *gomock.Controller) *MockUserRepository {
|
||||
mock := &MockUserRepository{ctrl: ctrl}
|
||||
mock.recorder = &MockUserRepositoryMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockUserRepository) EXPECT() *MockUserRepositoryMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// With mocks base method
|
||||
func (m *MockUserRepository) With(ctx context.Context, db *factory.DB) repository.UserRepository {
|
||||
ret := m.ctrl.Call(m, "With", ctx, db)
|
||||
ret0, _ := ret[0].(repository.UserRepository)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// With indicates an expected call of With
|
||||
func (mr *MockUserRepositoryMockRecorder) With(ctx, db interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "With", reflect.TypeOf((*MockUserRepository)(nil).With), ctx, db)
|
||||
}
|
||||
|
||||
// FindByEmail mocks base method
|
||||
func (m *MockUserRepository) FindByEmail(email string) (*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "FindByEmail", email)
|
||||
ret0, _ := ret[0].(*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByEmail indicates an expected call of FindByEmail
|
||||
func (mr *MockUserRepositoryMockRecorder) FindByEmail(email interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByEmail", reflect.TypeOf((*MockUserRepository)(nil).FindByEmail), email)
|
||||
}
|
||||
|
||||
// FindByUsername mocks base method
|
||||
func (m *MockUserRepository) FindByUsername(username string) (*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "FindByUsername", username)
|
||||
ret0, _ := ret[0].(*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByUsername indicates an expected call of FindByUsername
|
||||
func (mr *MockUserRepositoryMockRecorder) FindByUsername(username interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByUsername", reflect.TypeOf((*MockUserRepository)(nil).FindByUsername), username)
|
||||
}
|
||||
|
||||
// FindByID mocks base method
|
||||
func (m *MockUserRepository) FindByID(id uint64) (*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "FindByID", id)
|
||||
ret0, _ := ret[0].(*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByID indicates an expected call of FindByID
|
||||
func (mr *MockUserRepositoryMockRecorder) FindByID(id interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByID", reflect.TypeOf((*MockUserRepository)(nil).FindByID), id)
|
||||
}
|
||||
|
||||
// FindBySatosaID mocks base method
|
||||
func (m *MockUserRepository) FindBySatosaID(id string) (*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "FindBySatosaID", id)
|
||||
ret0, _ := ret[0].(*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindBySatosaID indicates an expected call of FindBySatosaID
|
||||
func (mr *MockUserRepositoryMockRecorder) FindBySatosaID(id interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindBySatosaID", reflect.TypeOf((*MockUserRepository)(nil).FindBySatosaID), id)
|
||||
}
|
||||
|
||||
// Find mocks base method
|
||||
func (m *MockUserRepository) Find(filter *types.UserFilter) ([]*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "Find", filter)
|
||||
ret0, _ := ret[0].([]*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Find indicates an expected call of Find
|
||||
func (mr *MockUserRepositoryMockRecorder) Find(filter interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Find", reflect.TypeOf((*MockUserRepository)(nil).Find), filter)
|
||||
}
|
||||
|
||||
// Create mocks base method
|
||||
func (m *MockUserRepository) Create(mod *types.User) (*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "Create", mod)
|
||||
ret0, _ := ret[0].(*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Create indicates an expected call of Create
|
||||
func (mr *MockUserRepositoryMockRecorder) Create(mod interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockUserRepository)(nil).Create), mod)
|
||||
}
|
||||
|
||||
// Update mocks base method
|
||||
func (m *MockUserRepository) Update(mod *types.User) (*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "Update", mod)
|
||||
ret0, _ := ret[0].(*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Update indicates an expected call of Update
|
||||
func (mr *MockUserRepositoryMockRecorder) Update(mod interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockUserRepository)(nil).Update), mod)
|
||||
}
|
||||
|
||||
// SuspendByID mocks base method
|
||||
func (m *MockUserRepository) SuspendByID(id uint64) error {
|
||||
ret := m.ctrl.Call(m, "SuspendByID", id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SuspendByID indicates an expected call of SuspendByID
|
||||
func (mr *MockUserRepositoryMockRecorder) SuspendByID(id interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SuspendByID", reflect.TypeOf((*MockUserRepository)(nil).SuspendByID), id)
|
||||
}
|
||||
|
||||
// UnsuspendByID mocks base method
|
||||
func (m *MockUserRepository) UnsuspendByID(id uint64) error {
|
||||
ret := m.ctrl.Call(m, "UnsuspendByID", id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UnsuspendByID indicates an expected call of UnsuspendByID
|
||||
func (mr *MockUserRepositoryMockRecorder) UnsuspendByID(id interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnsuspendByID", reflect.TypeOf((*MockUserRepository)(nil).UnsuspendByID), id)
|
||||
}
|
||||
|
||||
// DeleteByID mocks base method
|
||||
func (m *MockUserRepository) DeleteByID(id uint64) error {
|
||||
ret := m.ctrl.Call(m, "DeleteByID", id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteByID indicates an expected call of DeleteByID
|
||||
func (mr *MockUserRepositoryMockRecorder) DeleteByID(id interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteByID", reflect.TypeOf((*MockUserRepository)(nil).DeleteByID), id)
|
||||
}
|
||||
24
system/repository/repository_test.go
Normal file
24
system/repository/repository_test.go
Normal file
@ -0,0 +1,24 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
func tx(t *testing.T, f func() error) {
|
||||
db := DB(context.Background())
|
||||
|
||||
if err := db.Begin(); err != nil {
|
||||
t.Errorf("Could not begin transaction: %v", err)
|
||||
|
||||
}
|
||||
|
||||
if err := f(); err != nil {
|
||||
t.Errorf("Test transaction resulted in an error: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Rollback(); err != nil {
|
||||
t.Errorf("Could not rollback transaction: %v", err)
|
||||
}
|
||||
}
|
||||
@ -13,7 +13,7 @@ import (
|
||||
"github.com/crusttech/crust/system/service"
|
||||
)
|
||||
|
||||
func MountRoutes(oidcConfig *config.OIDC, jwtEncoder auth.TokenEncoder) func(chi.Router) {
|
||||
func MountRoutes(oidcConfig *config.OIDC, socialConfig *config.Social, jwtEncoder auth.TokenEncoder) func(chi.Router) {
|
||||
var err error
|
||||
var userSvc = service.User()
|
||||
var ctx = context.Background()
|
||||
@ -37,6 +37,8 @@ func MountRoutes(oidcConfig *config.OIDC, jwtEncoder auth.TokenEncoder) func(chi
|
||||
})
|
||||
}
|
||||
|
||||
NewSocial(socialConfig, jwtEncoder).MountRoutes(r)
|
||||
|
||||
// Provide raw `/auth` handlers
|
||||
Auth{}.New().Handlers(jwtEncoder).MountRoutes(r)
|
||||
|
||||
|
||||
180
system/rest/social.go
Normal file
180
system/rest/social.go
Normal file
@ -0,0 +1,180 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/markbates/goth"
|
||||
"github.com/markbates/goth/gothic"
|
||||
"github.com/markbates/goth/providers/facebook"
|
||||
"github.com/markbates/goth/providers/github"
|
||||
"github.com/markbates/goth/providers/gplus"
|
||||
"github.com/markbates/goth/providers/linkedin"
|
||||
"github.com/titpetric/factory/resputil"
|
||||
|
||||
"github.com/crusttech/crust/internal/auth"
|
||||
"github.com/crusttech/crust/internal/config"
|
||||
"github.com/crusttech/crust/system/service"
|
||||
)
|
||||
|
||||
type (
|
||||
Social struct {
|
||||
auth service.AuthService
|
||||
config *config.Social
|
||||
jwtEncoder auth.TokenEncoder
|
||||
}
|
||||
)
|
||||
|
||||
func NewSocial(config *config.Social, jwtEncoder auth.TokenEncoder) *Social {
|
||||
return &Social{
|
||||
auth: service.DefaultAuth,
|
||||
config: config,
|
||||
jwtEncoder: jwtEncoder,
|
||||
}
|
||||
}
|
||||
|
||||
func (ctrl *Social) MountRoutes(r chi.Router) {
|
||||
store := sessions.NewCookieStore([]byte(ctrl.config.SessionStoreSecret))
|
||||
store.MaxAge(ctrl.config.SessionStoreExpiry)
|
||||
store.Options.Path = "/social"
|
||||
store.Options.HttpOnly = true
|
||||
store.Options.Secure = false // @todo
|
||||
gothic.Store = store
|
||||
|
||||
getProviderConfig := func(provider string) (key, secret string, has bool) {
|
||||
key, keyOk := os.LookupEnv("AUTH_SOCIAL_" + provider + "_KEY")
|
||||
sec, secOk := os.LookupEnv("AUTH_SOCIAL_" + provider + "_SECRET")
|
||||
if keyOk && secOk {
|
||||
return key, sec, true
|
||||
} else {
|
||||
log.Print(
|
||||
"Binding auth endpoints without " + provider + " provider " +
|
||||
"(missing key/secret, check AUTH_" + provider + "_KEY, AUTH_" + provider + "_SECRET)")
|
||||
}
|
||||
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
if key, sec, ok := getProviderConfig("FACEBOOK"); ok {
|
||||
goth.UseProviders(facebook.New(key, sec, ctrl.config.Url+"/social/facebook/callback", "email"))
|
||||
}
|
||||
|
||||
if key, sec, ok := getProviderConfig("GPLUS"); ok {
|
||||
goth.UseProviders(gplus.New(key, sec, ctrl.config.Url+"/social/gplus/callback", "email"))
|
||||
}
|
||||
|
||||
if key, sec, ok := getProviderConfig("GITHUB"); ok {
|
||||
goth.UseProviders(github.New(key, sec, ctrl.config.Url+"/social/gplus/callback", "email"))
|
||||
}
|
||||
|
||||
if key, sec, ok := getProviderConfig("LINKEDIN"); ok {
|
||||
goth.UseProviders(linkedin.New(key, sec, ctrl.config.Url+"/social/linkedin/callback", "email"))
|
||||
}
|
||||
|
||||
// Copy provider from path (Chi URL param) to request context and return it
|
||||
copyProviderToContext := func(r *http.Request) *http.Request {
|
||||
return r.WithContext(context.WithValue(r.Context(), "provider", chi.URLParam(r, "provider")))
|
||||
}
|
||||
|
||||
r.Route("/social/{provider}", func(r chi.Router) {
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
r = copyProviderToContext(r)
|
||||
|
||||
// Always set redir cookie, even if not requested. If param is empty, cookie is removed
|
||||
ctrl.setCookie(w, r, "redir", r.URL.Query().Get("redir"))
|
||||
spew.Dump("REDIR=" + r.URL.Query().Get("redir"))
|
||||
|
||||
// try to get the user without re-authenticating
|
||||
if user, err := gothic.CompleteUserAuth(w, r); err != nil {
|
||||
gothic.BeginAuthHandler(w, r)
|
||||
} else {
|
||||
// We've successfully singed-in through 3rd party auth
|
||||
ctrl.handleSuccessfulAuth(w, r, user)
|
||||
}
|
||||
})
|
||||
|
||||
r.Get("/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
r = copyProviderToContext(r)
|
||||
|
||||
if user, err := gothic.CompleteUserAuth(w, r); err != nil {
|
||||
log.Printf("Failed to complete user auth: %v", err)
|
||||
ctrl.handleFailedCallback(w, r, err)
|
||||
} else {
|
||||
ctrl.handleSuccessfulAuth(w, r, user)
|
||||
}
|
||||
})
|
||||
|
||||
r.Get("/logout", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := gothic.Logout(w, r); err != nil {
|
||||
log.Printf("Failed to social logout: %v", err)
|
||||
}
|
||||
|
||||
w.Header().Set("Location", "/")
|
||||
w.WriteHeader(http.StatusTemporaryRedirect)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (ctrl *Social) handleFailedCallback(w http.ResponseWriter, r *http.Request, err error) {
|
||||
provider := chi.URLParam(r, "provider")
|
||||
|
||||
if strings.Contains(err.Error(), "Error processing your OAuth request: Invalid oauth_verifier parameter") {
|
||||
// Just take user through the same loop again
|
||||
w.Header().Set("Location", "/social/"+provider)
|
||||
w.WriteHeader(http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "SSO Error: %v", err.Error())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// Handles authentication via external auth providers of
|
||||
// unknown an user + appending authentication on external providers
|
||||
// to a current user
|
||||
func (ctrl *Social) handleSuccessfulAuth(w http.ResponseWriter, r *http.Request, cred goth.User) {
|
||||
log.Printf("Successful social login: %v", cred)
|
||||
|
||||
if u, err := ctrl.auth.With(r.Context()).Social(cred); err != nil {
|
||||
resputil.JSON(w, err)
|
||||
} else {
|
||||
ctrl.jwtEncoder.SetCookie(w, r, u)
|
||||
|
||||
if c, err := r.Cookie("redir"); c != nil && err == nil {
|
||||
spew.Dump("REDIR=" + c.Value)
|
||||
ctrl.setCookie(w, r, "redir", "")
|
||||
w.Header().Set("Location", c.Value)
|
||||
w.WriteHeader(http.StatusSeeOther)
|
||||
|
||||
}
|
||||
|
||||
resputil.JSON(w, u, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Extracts and authenticates JWT from context, validates claims
|
||||
func (ctrl *Social) setCookie(w http.ResponseWriter, r *http.Request, name, value string) {
|
||||
cookie := &http.Cookie{
|
||||
Name: name,
|
||||
Expires: time.Now().Add(time.Duration(ctrl.config.SessionStoreExpiry) * time.Second),
|
||||
Secure: r.URL.Scheme == "https",
|
||||
Domain: r.URL.Hostname(),
|
||||
Path: "/social",
|
||||
}
|
||||
|
||||
if value == "" {
|
||||
cookie.Expires = time.Unix(0, 0)
|
||||
} else {
|
||||
cookie.Value = value
|
||||
}
|
||||
|
||||
http.SetCookie(w, cookie)
|
||||
}
|
||||
@ -25,7 +25,7 @@ func Routes() *chi.Mux {
|
||||
// Only protect application routes with JWT
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(jwtVerifier, jwtAuthenticator)
|
||||
mountRoutes(r, flags.http, rest.MountRoutes(flags.oidc, jwtEncoder))
|
||||
mountRoutes(r, flags.http, rest.MountRoutes(flags.oidc, flags.social, jwtEncoder))
|
||||
})
|
||||
|
||||
printRoutes(r, flags.http)
|
||||
|
||||
185
system/service/auth.go
Normal file
185
system/service/auth.go
Normal file
@ -0,0 +1,185 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"github.com/markbates/goth"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/crusttech/crust/system/repository"
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
type (
|
||||
auth struct {
|
||||
db db
|
||||
ctx context.Context
|
||||
|
||||
credentials repository.CredentialsRepository
|
||||
users repository.UserRepository
|
||||
}
|
||||
|
||||
AuthService interface {
|
||||
With(ctx context.Context) AuthService
|
||||
|
||||
Social(profile goth.User) (*types.User, error)
|
||||
|
||||
CheckPassword(username, password string) (*types.User, error)
|
||||
ChangePassword(user *types.User, password string) error
|
||||
CheckCredentials(credentialsID uint64, secret string) (*types.User, error)
|
||||
RevokeCredentialsByID(user *types.User, credentialsID uint64) error
|
||||
}
|
||||
)
|
||||
|
||||
func Auth() AuthService {
|
||||
return (&auth{}).With(context.Background())
|
||||
}
|
||||
|
||||
func (svc *auth) With(ctx context.Context) AuthService {
|
||||
db := repository.DB(ctx)
|
||||
return &auth{
|
||||
db: db,
|
||||
ctx: ctx,
|
||||
|
||||
credentials: repository.Credentials(ctx, db),
|
||||
users: repository.User(ctx, db),
|
||||
}
|
||||
}
|
||||
|
||||
// Social user verifies existance by using email value from social profile and creates user if needed
|
||||
//
|
||||
// It does not update user's info
|
||||
func (svc *auth) Social(profile goth.User) (u *types.User, err error) {
|
||||
var kind types.CredentialsKind
|
||||
|
||||
switch profile.Provider {
|
||||
case "facebook", "gplus", "github", "linkedin":
|
||||
kind = types.CredentialsKind(profile.Provider)
|
||||
default:
|
||||
return nil, errors.New("Unsupported provider")
|
||||
}
|
||||
|
||||
if profile.Email == "" {
|
||||
return nil, errors.New("Can not use profile data without an email")
|
||||
}
|
||||
|
||||
return u, svc.db.Transaction(func() error {
|
||||
var c *types.Credentials
|
||||
if cc, err := svc.credentials.FindByCredentials(kind, profile.UserID); err == nil {
|
||||
// Credentials found, load user
|
||||
for _, c := range cc {
|
||||
if !c.Valid() {
|
||||
continue
|
||||
}
|
||||
|
||||
if u, err = svc.users.FindByID(c.OwnerID); err != nil {
|
||||
return nil
|
||||
} else if u.Valid() && u.Email != profile.Email {
|
||||
return errors.Errorf(
|
||||
"Refusing to authenticate with non matching emails (profile: %v, db: %v) on credentials (ID: %v)",
|
||||
profile.Email,
|
||||
u.Email,
|
||||
c.ID)
|
||||
} else if u.Valid() {
|
||||
// Valid user, matching emails. Bingo!
|
||||
return nil
|
||||
} else {
|
||||
// Scenario: linked to an invalid user
|
||||
u = nil
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// If we could not find anything useful,
|
||||
// we can search for user via email
|
||||
} else {
|
||||
// A serious error occured, bail out...
|
||||
return err
|
||||
}
|
||||
|
||||
// Find user via his email
|
||||
if u, err = svc.users.FindByEmail(profile.Email); err == repository.ErrUserNotFound {
|
||||
// In case we do not have this email, create a new user
|
||||
u = &types.User{
|
||||
Email: profile.Email,
|
||||
Name: profile.Name,
|
||||
Username: profile.NickName,
|
||||
Handle: profile.NickName,
|
||||
}
|
||||
|
||||
if u, err = svc.users.Create(u); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c = &types.Credentials{
|
||||
Kind: kind,
|
||||
OwnerID: u.ID,
|
||||
Credentials: profile.UserID,
|
||||
}
|
||||
|
||||
if !profile.ExpiresAt.IsZero() {
|
||||
// Copy expiration date when provided
|
||||
c.ExpiresAt = &profile.ExpiresAt
|
||||
}
|
||||
|
||||
if c, err = svc.credentials.Create(c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf(
|
||||
"Autheticated user (%v, %v) via %s, created user and credentials (%v)",
|
||||
u.ID,
|
||||
u.Email,
|
||||
profile.Provider,
|
||||
c.ID,
|
||||
)
|
||||
|
||||
// User created
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
} else if !u.Valid() {
|
||||
return errors.Errorf(
|
||||
"Social login to an invalid/suspended user (user ID: %v)",
|
||||
u.ID,
|
||||
)
|
||||
}
|
||||
|
||||
log.Printf(
|
||||
"Autheticated user (%v, %v) via %s, existing user",
|
||||
u.ID,
|
||||
u.Email,
|
||||
profile.Provider,
|
||||
)
|
||||
|
||||
// User loaded, carry on.
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// CheckPassword verifies username/password combination
|
||||
//
|
||||
// Expects plain text password as an input
|
||||
func (svc *auth) CheckPassword(username, password string) (*types.User, error) {
|
||||
panic("svc.auth.CheckPassword, not implemented")
|
||||
}
|
||||
|
||||
// ChangePassword (soft) deletes old password entry and creates a new one
|
||||
//
|
||||
// Expects plain text password as an input
|
||||
func (svc *auth) ChangePassword(user *types.User, password string) error {
|
||||
panic("svc.auth.ChangePassword, not implemented")
|
||||
}
|
||||
|
||||
// CheckCredentials searches for credentials/secret combination and returns loaded user if successful
|
||||
func (svc *auth) CheckCredentials(credentialsID uint64, secret string) (*types.User, error) {
|
||||
panic("svc.auth.CheckCredentials, not implemented")
|
||||
}
|
||||
|
||||
// RevokeCredentialsByID (soft) deletes credentials by id
|
||||
func (svc *auth) RevokeCredentialsByID(user *types.User, credentialsID uint64) error {
|
||||
panic("svc.auth.RevokeCredentialsByID, not implemented")
|
||||
}
|
||||
|
||||
var _ AuthService = &auth{}
|
||||
85
system/service/auth_test.go
Normal file
85
system/service/auth_test.go
Normal file
@ -0,0 +1,85 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/markbates/goth"
|
||||
|
||||
"github.com/crusttech/crust/system/repository"
|
||||
repomock "github.com/crusttech/crust/system/repository/mocks"
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
func TestSocialSigninWithExistingCredentials(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
|
||||
var u = &types.User{ID: 300000, Email: "foo@example.tld"}
|
||||
var c = &types.Credentials{ID: 200000, OwnerID: u.ID}
|
||||
var p = goth.User{UserID: "some-profile-id", Provider: "gplus", Email: u.Email}
|
||||
|
||||
crdRpoMock := repomock.NewMockCredentialsRepository(mockCtrl)
|
||||
crdRpoMock.EXPECT().
|
||||
FindByCredentials(types.CredentialsKindGPlus, p.UserID).
|
||||
Times(1).
|
||||
Return(types.CredentialsSet{c}, nil)
|
||||
|
||||
usrRpoMock := repomock.NewMockUserRepository(mockCtrl)
|
||||
usrRpoMock.EXPECT().FindByID(u.ID).Times(1).Return(u, nil)
|
||||
|
||||
svc := &auth{
|
||||
db: &mockDB{},
|
||||
users: usrRpoMock,
|
||||
credentials: crdRpoMock,
|
||||
}
|
||||
|
||||
{
|
||||
auser, err := svc.Social(p)
|
||||
assert(t, err == nil, "Auth.Social error: %+v", err)
|
||||
assert(t, auser.ID == u.ID, "Did not receive expected user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocialSigninWithNewUserCredentials(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
|
||||
var u = &types.User{ID: 300000, Email: "foo@example.tld"}
|
||||
var c = &types.Credentials{ID: 200000, OwnerID: u.ID}
|
||||
var p = goth.User{UserID: "some-profile-id", Provider: "gplus", Email: u.Email}
|
||||
|
||||
crdRpoMock := repomock.NewMockCredentialsRepository(mockCtrl)
|
||||
crdRpoMock.EXPECT().
|
||||
FindByCredentials(types.CredentialsKindGPlus, p.UserID).
|
||||
Times(1).
|
||||
Return(types.CredentialsSet{}, nil)
|
||||
|
||||
crdRpoMock.EXPECT().
|
||||
Create(&types.Credentials{Kind: types.CredentialsKindGPlus, OwnerID: u.ID, Credentials: p.UserID}).
|
||||
Times(1).
|
||||
Return(c, nil)
|
||||
|
||||
usrRpoMock := repomock.NewMockUserRepository(mockCtrl)
|
||||
usrRpoMock.EXPECT().
|
||||
FindByEmail(u.Email).
|
||||
Times(1).
|
||||
Return(nil, repository.ErrUserNotFound)
|
||||
|
||||
usrRpoMock.EXPECT().
|
||||
Create(&types.User{Email: "foo@example.tld"}).
|
||||
Times(1).
|
||||
Return(u, nil)
|
||||
|
||||
svc := &auth{
|
||||
db: &mockDB{},
|
||||
users: usrRpoMock,
|
||||
credentials: crdRpoMock,
|
||||
}
|
||||
|
||||
{
|
||||
auser, err := svc.Social(p)
|
||||
assert(t, err == nil, "Auth.Social error: %+v", err)
|
||||
assert(t, auser.ID == u.ID, "Did not receive expected user")
|
||||
}
|
||||
}
|
||||
21
system/service/main_test.go
Normal file
21
system/service/main_test.go
Normal file
@ -0,0 +1,21 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockDB struct{}
|
||||
|
||||
func (mockDB) Transaction(callback func() error) error { return callback() }
|
||||
|
||||
func assert(t *testing.T, ok bool, format string, args ...interface{}) bool {
|
||||
if !ok {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
caller := fmt.Sprintf("\nAsserted at:%s:%d", file, line)
|
||||
|
||||
t.Fatalf(format+caller, args...)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
@ -1,135 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: system/service/organisation.go
|
||||
|
||||
// Package service is a generated GoMock package.
|
||||
package service
|
||||
|
||||
import (
|
||||
context "context"
|
||||
types "github.com/crusttech/crust/system/types"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockOrganisationService is a mock of OrganisationService interface
|
||||
type MockOrganisationService struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockOrganisationServiceMockRecorder
|
||||
}
|
||||
|
||||
// MockOrganisationServiceMockRecorder is the mock recorder for MockOrganisationService
|
||||
type MockOrganisationServiceMockRecorder struct {
|
||||
mock *MockOrganisationService
|
||||
}
|
||||
|
||||
// NewMockOrganisationService creates a new mock instance
|
||||
func NewMockOrganisationService(ctrl *gomock.Controller) *MockOrganisationService {
|
||||
mock := &MockOrganisationService{ctrl: ctrl}
|
||||
mock.recorder = &MockOrganisationServiceMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockOrganisationService) EXPECT() *MockOrganisationServiceMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// With mocks base method
|
||||
func (m *MockOrganisationService) With(ctx context.Context) OrganisationService {
|
||||
ret := m.ctrl.Call(m, "With", ctx)
|
||||
ret0, _ := ret[0].(OrganisationService)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// With indicates an expected call of With
|
||||
func (mr *MockOrganisationServiceMockRecorder) With(ctx interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "With", reflect.TypeOf((*MockOrganisationService)(nil).With), ctx)
|
||||
}
|
||||
|
||||
// FindByID mocks base method
|
||||
func (m *MockOrganisationService) FindByID(organisationID uint64) (*types.Organisation, error) {
|
||||
ret := m.ctrl.Call(m, "FindByID", organisationID)
|
||||
ret0, _ := ret[0].(*types.Organisation)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByID indicates an expected call of FindByID
|
||||
func (mr *MockOrganisationServiceMockRecorder) FindByID(organisationID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByID", reflect.TypeOf((*MockOrganisationService)(nil).FindByID), organisationID)
|
||||
}
|
||||
|
||||
// Find mocks base method
|
||||
func (m *MockOrganisationService) Find(filter *types.OrganisationFilter) ([]*types.Organisation, error) {
|
||||
ret := m.ctrl.Call(m, "Find", filter)
|
||||
ret0, _ := ret[0].([]*types.Organisation)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Find indicates an expected call of Find
|
||||
func (mr *MockOrganisationServiceMockRecorder) Find(filter interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Find", reflect.TypeOf((*MockOrganisationService)(nil).Find), filter)
|
||||
}
|
||||
|
||||
// Create mocks base method
|
||||
func (m *MockOrganisationService) Create(organisation *types.Organisation) (*types.Organisation, error) {
|
||||
ret := m.ctrl.Call(m, "Create", organisation)
|
||||
ret0, _ := ret[0].(*types.Organisation)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Create indicates an expected call of Create
|
||||
func (mr *MockOrganisationServiceMockRecorder) Create(organisation interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockOrganisationService)(nil).Create), organisation)
|
||||
}
|
||||
|
||||
// Update mocks base method
|
||||
func (m *MockOrganisationService) Update(organisation *types.Organisation) (*types.Organisation, error) {
|
||||
ret := m.ctrl.Call(m, "Update", organisation)
|
||||
ret0, _ := ret[0].(*types.Organisation)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Update indicates an expected call of Update
|
||||
func (mr *MockOrganisationServiceMockRecorder) Update(organisation interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockOrganisationService)(nil).Update), organisation)
|
||||
}
|
||||
|
||||
// Archive mocks base method
|
||||
func (m *MockOrganisationService) Archive(ID uint64) error {
|
||||
ret := m.ctrl.Call(m, "Archive", ID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Archive indicates an expected call of Archive
|
||||
func (mr *MockOrganisationServiceMockRecorder) Archive(ID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Archive", reflect.TypeOf((*MockOrganisationService)(nil).Archive), ID)
|
||||
}
|
||||
|
||||
// Unarchive mocks base method
|
||||
func (m *MockOrganisationService) Unarchive(ID uint64) error {
|
||||
ret := m.ctrl.Call(m, "Unarchive", ID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Unarchive indicates an expected call of Unarchive
|
||||
func (mr *MockOrganisationServiceMockRecorder) Unarchive(ID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unarchive", reflect.TypeOf((*MockOrganisationService)(nil).Unarchive), ID)
|
||||
}
|
||||
|
||||
// Delete mocks base method
|
||||
func (m *MockOrganisationService) Delete(ID uint64) error {
|
||||
ret := m.ctrl.Call(m, "Delete", ID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete
|
||||
func (mr *MockOrganisationServiceMockRecorder) Delete(ID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockOrganisationService)(nil).Delete), ID)
|
||||
}
|
||||
@ -4,8 +4,15 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type (
|
||||
db interface {
|
||||
Transaction(callback func() error) error
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
o sync.Once
|
||||
DefaultAuth AuthService
|
||||
DefaultUser UserService
|
||||
DefaultTeam TeamService
|
||||
DefaultOrganisation OrganisationService
|
||||
@ -13,6 +20,7 @@ var (
|
||||
|
||||
func Init() {
|
||||
o.Do(func() {
|
||||
DefaultAuth = Auth()
|
||||
DefaultUser = User()
|
||||
DefaultTeam = Team()
|
||||
DefaultOrganisation = Organisation()
|
||||
|
||||
@ -1,40 +1,31 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/crusttech/crust/system/types"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/titpetric/factory"
|
||||
)
|
||||
|
||||
func TestUser(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
|
||||
usr := &types.User{ID: factory.Sonyflake.NextID()}
|
||||
|
||||
usrRpoMock := NewMockRepository(mockCtrl)
|
||||
usrRpoMock.EXPECT().WithCtx(gomock.Any()).AnyTimes().Return(usrRpoMock)
|
||||
usrRpoMock.EXPECT().
|
||||
FindUserByID(usr.ID).
|
||||
Times(1).
|
||||
Return(usr, nil)
|
||||
|
||||
svc := User()
|
||||
svc.rpo = usrRpoMock
|
||||
|
||||
found, err := svc.FindByID(context.Background(), usr.ID)
|
||||
if err != nil {
|
||||
t.Fatal("Did not expect an error")
|
||||
}
|
||||
|
||||
if found == nil {
|
||||
t.Fatal("Expecting an user to be found")
|
||||
}
|
||||
|
||||
if found.ID != usr.ID {
|
||||
t.Fatal("Expecting found user to have the same ID as the find param")
|
||||
}
|
||||
}
|
||||
// func TestUser(t *testing.T) {
|
||||
// mockCtrl := gomock.NewController(t)
|
||||
// defer mockCtrl.Finish()
|
||||
//
|
||||
// usr := &types.User{ID: factory.Sonyflake.NextID()}
|
||||
//
|
||||
// usrRpoMock := NewMockRepository(mockCtrl)
|
||||
// usrRpoMock.EXPECT().WithCtx(gomock.Any()).AnyTimes().Return(usrRpoMock)
|
||||
// usrRpoMock.EXPECT().
|
||||
// FindUserByID(usr.ID).
|
||||
// Times(1).
|
||||
// Return(usr, nil)
|
||||
//
|
||||
// svc := User()
|
||||
// svc.rpo = usrRpoMock
|
||||
//
|
||||
// found, err := svc.FindByID(context.Background(), usr.ID)
|
||||
// if err != nil {
|
||||
// t.Fatal("Did not expect an error")
|
||||
// }
|
||||
//
|
||||
// if found == nil {
|
||||
// t.Fatal("Expecting an user to be found")
|
||||
// }
|
||||
//
|
||||
// if found.ID != usr.ID {
|
||||
// t.Fatal("Expecting found user to have the same ID as the find param")
|
||||
// }
|
||||
// }
|
||||
|
||||
67
system/types/credentials.gen.go
Normal file
67
system/types/credentials.gen.go
Normal file
@ -0,0 +1,67 @@
|
||||
package types
|
||||
|
||||
// Hello! This file is auto-generated.
|
||||
|
||||
type (
|
||||
|
||||
// CredentialsSet slice of Credentials
|
||||
//
|
||||
// This type is auto-generated.
|
||||
CredentialsSet []*Credentials
|
||||
)
|
||||
|
||||
// Walk iterates through every slice item and calls w(Credentials) err
|
||||
//
|
||||
// This function is auto-generated.
|
||||
func (set CredentialsSet) Walk(w func(*Credentials) error) (err error) {
|
||||
for i := range set {
|
||||
if err = w(set[i]); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Filter iterates through every slice item, calls f(Credentials) (bool, err) and return filtered slice
|
||||
//
|
||||
// This function is auto-generated.
|
||||
func (set CredentialsSet) Filter(f func(*Credentials) (bool, error)) (out CredentialsSet, err error) {
|
||||
var ok bool
|
||||
out = CredentialsSet{}
|
||||
for i := range set {
|
||||
if ok, err = f(set[i]); err != nil {
|
||||
return
|
||||
} else if ok {
|
||||
out = append(out, set[i])
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// FindByID finds items from slice by its ID property
|
||||
//
|
||||
// This function is auto-generated.
|
||||
func (set CredentialsSet) FindByID(ID uint64) *Credentials {
|
||||
for i := range set {
|
||||
if set[i].ID == ID {
|
||||
return set[i]
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IDs returns a slice of uint64s from all items in the set
|
||||
//
|
||||
// This function is auto-generated.
|
||||
func (set CredentialsSet) IDs() (IDs []uint64) {
|
||||
IDs = make([]uint64, len(set))
|
||||
|
||||
for i := range set {
|
||||
IDs[i] = set[i].ID
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
39
system/types/credentials.go
Normal file
39
system/types/credentials.go
Normal file
@ -0,0 +1,39 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx/types"
|
||||
)
|
||||
|
||||
type (
|
||||
Credentials struct {
|
||||
ID uint64 `json:"credentialsID,string" db:"id"`
|
||||
OwnerID uint64 `json:"ownerID,string" db:"rel_owner"`
|
||||
Label string `json:"label" db:"label"`
|
||||
Kind CredentialsKind `json:"kind" db:"kind"`
|
||||
Credentials string `json:"-" db:"credentials"`
|
||||
Meta types.JSONText `json:"-" db:"meta"`
|
||||
ExpiresAt *time.Time `json:"expiresAt,omitempty" db:"expires_at"`
|
||||
CreatedAt time.Time `json:"createdAt,omitempty" db:"created_at"`
|
||||
DeletedAt *time.Time `json:"deletedAt,omitempty" db:"deleted_at"`
|
||||
}
|
||||
|
||||
CredentialsKind string
|
||||
)
|
||||
|
||||
const (
|
||||
// Use as a password for users or as API secret for bots (and credentials-id as a key) as a value for "credentials"
|
||||
CredentialsKindHash CredentialsKind = "hash"
|
||||
|
||||
// Identity (profile-id) stored under "credentials"
|
||||
CredentialsKindFacebook CredentialsKind = "facebook"
|
||||
CredentialsKindGPlus CredentialsKind = "gplus"
|
||||
CredentialsKindGitHub CredentialsKind = "github"
|
||||
CredentialsKindLinkedin CredentialsKind = "linkedin"
|
||||
// CredentialsKindSatosa CredentialsKind = "satosa"
|
||||
)
|
||||
|
||||
func (u *Credentials) Valid() bool {
|
||||
return u.ID > 0 && (u.ExpiresAt == nil || u.ExpiresAt.Before(time.Now())) && u.DeletedAt == nil
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user