upd(all): merge fix/coupling-issues-stage-1
This commit is contained in:
85
system/internal/repository/application.go
Normal file
85
system/internal/repository/application.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
type (
|
||||
ApplicationRepository interface {
|
||||
With(ctx context.Context, db *factory.DB) ApplicationRepository
|
||||
|
||||
FindByID(id uint64) (*types.Application, error)
|
||||
Find() (types.ApplicationSet, error)
|
||||
|
||||
Create(mod *types.Application) (*types.Application, error)
|
||||
Update(mod *types.Application) (*types.Application, error)
|
||||
|
||||
DeleteByID(id uint64) error
|
||||
}
|
||||
|
||||
application struct {
|
||||
*repository
|
||||
|
||||
// sql table reference
|
||||
applications string
|
||||
members string
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
sqlApplicationColumns = "id, rel_owner, name, enabled, unify, created_at, updated_at, deleted_at"
|
||||
sqlApplicationScope = "deleted_at IS NULL"
|
||||
|
||||
ErrApplicationNotFound = repositoryError("ApplicationNotFound")
|
||||
)
|
||||
|
||||
func Application(ctx context.Context, db *factory.DB) ApplicationRepository {
|
||||
return (&application{}).With(ctx, db)
|
||||
}
|
||||
|
||||
func (r *application) With(ctx context.Context, db *factory.DB) ApplicationRepository {
|
||||
return &application{
|
||||
repository: r.repository.With(ctx, db),
|
||||
applications: "sys_application",
|
||||
}
|
||||
}
|
||||
|
||||
func (r *application) FindByID(id uint64) (*types.Application, error) {
|
||||
sql := "SELECT " + sqlApplicationColumns + " FROM " + r.applications + " WHERE id = ? AND " + sqlApplicationScope
|
||||
mod := &types.Application{}
|
||||
|
||||
return mod, isFound(r.db().Get(mod, sql, id), mod.ID > 0, ErrApplicationNotFound)
|
||||
}
|
||||
|
||||
func (r *application) Find() (types.ApplicationSet, error) {
|
||||
rval := make([]*types.Application, 0)
|
||||
params := make([]interface{}, 0)
|
||||
|
||||
sql := "SELECT " + sqlApplicationColumns + " FROM " + r.applications + " WHERE " + sqlApplicationScope
|
||||
|
||||
sql += " ORDER BY id ASC"
|
||||
|
||||
return rval, r.db().Select(&rval, sql, params...)
|
||||
}
|
||||
|
||||
func (r *application) Create(mod *types.Application) (*types.Application, error) {
|
||||
mod.ID = factory.Sonyflake.NextID()
|
||||
mod.CreatedAt = time.Now()
|
||||
|
||||
return mod, r.db().Insert(r.applications, mod)
|
||||
}
|
||||
|
||||
func (r *application) Update(mod *types.Application) (*types.Application, error) {
|
||||
mod.UpdatedAt = timeNowPtr()
|
||||
|
||||
return mod, r.db().Replace(r.applications, mod)
|
||||
}
|
||||
|
||||
func (r *application) DeleteByID(id uint64) error {
|
||||
return r.updateColumnByID(r.applications, "deleted_at", time.Now(), id)
|
||||
}
|
||||
67
system/internal/repository/applications_test.go
Normal file
67
system/internal/repository/applications_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/system/types"
|
||||
|
||||
. "github.com/crusttech/crust/internal/test"
|
||||
)
|
||||
|
||||
func TestApplication(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode.")
|
||||
return
|
||||
}
|
||||
|
||||
db := factory.Database.MustGet()
|
||||
|
||||
// Create application repository.
|
||||
crepo := Application(context.Background(), db)
|
||||
|
||||
// Run tests in transaction to maintain DB state.
|
||||
Error(t, db.Transaction(func() error {
|
||||
db.Delete("sys_application", "1=1")
|
||||
|
||||
app := &types.Application{
|
||||
Name: "created",
|
||||
Enabled: true,
|
||||
OwnerID: 1,
|
||||
Unify: &types.ApplicationUnify{
|
||||
Name: "created",
|
||||
Listed: true,
|
||||
Order: 1,
|
||||
Icon: "...ico",
|
||||
},
|
||||
}
|
||||
|
||||
app, err := crepo.Create(app)
|
||||
NoError(t, err, "Application.Create error: %+v", err)
|
||||
Assert(t, app.Valid(), "Expecting application to be valid after creation")
|
||||
Assert(t, app.Name == "created", "Expecting application name to be set, got %q", app.Name)
|
||||
Assert(t, app.Enabled, "Expecting application to be enabled")
|
||||
Assert(t, app.Unify.Name == "created", "Expecting application name to be set in unify, got %q", app.Name)
|
||||
Assert(t, app.Unify.Listed, "Expecting application to be listed in unify")
|
||||
Assert(t, app.Unify.Order == 1, "Expecting application name to have order val 1")
|
||||
|
||||
app.Name = "updated"
|
||||
app.Enabled = false
|
||||
app.Unify.Name = "updated"
|
||||
app.Unify.Listed = false
|
||||
app, err = crepo.Update(app)
|
||||
|
||||
NoError(t, err, "Application.Create error: %+v", err)
|
||||
Assert(t, err == nil, "Application.Create error: %+v", err)
|
||||
Assert(t, app.Name == "updated", "Expecting application name to be updated")
|
||||
Assert(t, !app.Enabled, "Expecting application to be disabled")
|
||||
Assert(t, app.Unify.Name == "updated", "Expecting application name to be updated in unify")
|
||||
Assert(t, !app.Unify.Listed, "Expecting application to be unlisted in unify")
|
||||
|
||||
return errors.New("Rollback")
|
||||
}), "expected rollback error")
|
||||
|
||||
}
|
||||
100
system/internal/repository/credentials.go
Normal file
100
system/internal/repository/credentials.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
type (
|
||||
CredentialsRepository interface {
|
||||
With(ctx context.Context, db *factory.DB) CredentialsRepository
|
||||
|
||||
FindByID(ID uint64) (*types.Credentials, error)
|
||||
FindByCredentials(kind types.CredentialsKind, credentials string) (cc types.CredentialsSet, err error)
|
||||
FindByKind(ownerID uint64, kind types.CredentialsKind) (cc types.CredentialsSet, err error)
|
||||
FindByOwnerID(ownerID uint64) (cc types.CredentialsSet, err error)
|
||||
Find() (cc types.CredentialsSet, err error)
|
||||
|
||||
Create(c *types.Credentials) (*types.Credentials, error)
|
||||
DeleteByID(id uint64) error
|
||||
}
|
||||
|
||||
credentials struct {
|
||||
*repository
|
||||
|
||||
// sql table reference
|
||||
tblname string
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
sqlCredentialsColumns = "id, rel_owner, kind, label, credentials, meta, expires_at, " +
|
||||
"created_at, deleted_at"
|
||||
sqlCredentialsScope = "deleted_at IS NULL"
|
||||
sqlCredentialsSelect = "SELECT " + sqlCredentialsColumns + " FROM %s WHERE " + sqlCredentialsScope
|
||||
|
||||
ErrCredentialsNotFound = repositoryError("CredentialsNotFound")
|
||||
)
|
||||
|
||||
func Credentials(ctx context.Context, db *factory.DB) CredentialsRepository {
|
||||
return (&credentials{}).With(ctx, db)
|
||||
}
|
||||
|
||||
func (r *credentials) With(ctx context.Context, db *factory.DB) CredentialsRepository {
|
||||
return &credentials{
|
||||
repository: r.repository.With(ctx, db),
|
||||
tblname: "sys_credentials",
|
||||
}
|
||||
}
|
||||
|
||||
func (r *credentials) FindByID(ID uint64) (*types.Credentials, error) {
|
||||
sql := fmt.Sprintf(sqlCredentialsSelect, r.tblname) + " AND id = ?"
|
||||
mod := &types.Credentials{}
|
||||
|
||||
return mod, isFound(r.db().Get(mod, sql, ID), mod.ID > 0, ErrCredentialsNotFound)
|
||||
}
|
||||
|
||||
func (r *credentials) FindByCredentials(kind types.CredentialsKind, credentials string) (cc types.CredentialsSet, err error) {
|
||||
return r.fetchSet(
|
||||
fmt.Sprintf(sqlCredentialsSelect+" AND kind = ? AND credentials = ?", r.tblname),
|
||||
kind,
|
||||
credentials)
|
||||
}
|
||||
|
||||
func (r *credentials) FindByKind(ownerID uint64, kind types.CredentialsKind) (cc types.CredentialsSet, err error) {
|
||||
return r.fetchSet(
|
||||
fmt.Sprintf(sqlCredentialsSelect+" AND rel_owner = ? AND kind = ?", r.tblname),
|
||||
ownerID,
|
||||
kind)
|
||||
}
|
||||
|
||||
func (r *credentials) FindByOwnerID(ownerID uint64) (cc types.CredentialsSet, err error) {
|
||||
return r.fetchSet(
|
||||
fmt.Sprintf(sqlCredentialsSelect+" AND rel_owner = ?", r.tblname),
|
||||
ownerID)
|
||||
}
|
||||
|
||||
func (r *credentials) Find() (cc types.CredentialsSet, err error) {
|
||||
return r.fetchSet(
|
||||
fmt.Sprintf(sqlCredentialsSelect, r.tblname))
|
||||
}
|
||||
|
||||
func (r *credentials) fetchSet(sql string, args ...interface{}) (cc types.CredentialsSet, err error) {
|
||||
cc = types.CredentialsSet{}
|
||||
return cc, r.db().Select(&cc, sql, args...)
|
||||
}
|
||||
|
||||
func (r *credentials) Create(c *types.Credentials) (*types.Credentials, error) {
|
||||
c.ID = factory.Sonyflake.NextID()
|
||||
c.CreatedAt = time.Now()
|
||||
return c, r.db().Insert(r.tblname, c)
|
||||
}
|
||||
|
||||
func (r *credentials) DeleteByID(id uint64) error {
|
||||
return r.updateColumnByID(r.tblname, "deleted_at", time.Now(), id)
|
||||
}
|
||||
59
system/internal/repository/credentials_test.go
Normal file
59
system/internal/repository/credentials_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/system/types"
|
||||
|
||||
. "github.com/crusttech/crust/internal/test"
|
||||
)
|
||||
|
||||
func TestCredentials(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode.")
|
||||
return
|
||||
}
|
||||
|
||||
db := factory.Database.MustGet()
|
||||
|
||||
// Create credentials repository.
|
||||
crepo := Credentials(context.Background(), db)
|
||||
|
||||
// Run tests in transaction to maintain DB state.
|
||||
Error(t, db.Transaction(func() error {
|
||||
db.Delete("sys_credentials", "1=1")
|
||||
|
||||
cc := types.CredentialsSet{
|
||||
&types.Credentials{OwnerID: 10000, Kind: types.CredentialsKindLinkedin, Credentials: "linkedin-profile-id"},
|
||||
&types.Credentials{OwnerID: 10000, Kind: types.CredentialsKindGPlus, Credentials: "gplus-profile-id"},
|
||||
&types.Credentials{OwnerID: 20000, Kind: types.CredentialsKindFacebook, Credentials: "facebook-profile-id"},
|
||||
}
|
||||
|
||||
for _, c := range cc {
|
||||
cNew, err := crepo.Create(c)
|
||||
assert(t, err == nil, "Credentials.Create error: %+v", err)
|
||||
assert(t, c.ID > 0, "Expecting credentials to have a valid ID")
|
||||
assert(t, c.Valid(), "Expecting credentials to be valid after creation")
|
||||
|
||||
_, err = crepo.FindByID(cNew.ID)
|
||||
assert(t, err == nil, "Credentials.FindByID error: %+v", err)
|
||||
|
||||
{
|
||||
r, err := crepo.FindByKind(c.OwnerID, c.Kind)
|
||||
assert(t, err == nil, "Credentials.FindByKind error: %+v", err)
|
||||
assert(t, len(r) == 1, "Expecting exactly 1 result from FindByKind, got: %v", len(r))
|
||||
}
|
||||
|
||||
{
|
||||
r, err := crepo.FindByCredentials(c.Kind, c.Credentials)
|
||||
assert(t, err == nil, "Credentials.FindByKind error: %+v", err)
|
||||
assert(t, len(r) == 1, "Expecting exactly 1 result from FindByCredentials, got: %v", len(r))
|
||||
}
|
||||
}
|
||||
return errors.New("Rollback")
|
||||
}), "expected rollback error")
|
||||
}
|
||||
18
system/internal/repository/error.go
Normal file
18
system/internal/repository/error.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package repository
|
||||
|
||||
type (
|
||||
repositoryError string
|
||||
)
|
||||
|
||||
const (
|
||||
ErrDatabaseError = repositoryError("DatabaseError")
|
||||
ErrNotImplemented = repositoryError("NotImplemented")
|
||||
)
|
||||
|
||||
func (e repositoryError) Error() string {
|
||||
return e.String()
|
||||
}
|
||||
|
||||
func (e repositoryError) String() string {
|
||||
return "crust.auth.repository." + string(e)
|
||||
}
|
||||
54
system/internal/repository/main_test.go
Normal file
54
system/internal/repository/main_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
"github.com/namsral/flag"
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
systemMigrate "github.com/crusttech/crust/system/db"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// @todo this is a very optimistic initialization, make it more robust
|
||||
godotenv.Load("../../.env")
|
||||
|
||||
prefix := "system"
|
||||
dsn := ""
|
||||
|
||||
p := func(s string) string {
|
||||
return prefix + "-" + s
|
||||
}
|
||||
|
||||
flag.StringVar(&dsn, p("db-dsn"), "crust:crust@tcp(db1:3306)/crust?collation=utf8mb4_general_ci", "DSN for database connection")
|
||||
flag.Parse()
|
||||
|
||||
factory.Database.Add("default", dsn)
|
||||
factory.Database.Add("system", dsn)
|
||||
|
||||
db := factory.Database.MustGet()
|
||||
db.Profiler = &factory.Database.ProfilerStdout
|
||||
|
||||
// migrate database schema
|
||||
if err := systemMigrate.Migrate(db); err != nil {
|
||||
log.Printf("Error running migrations: %+v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func assert(t *testing.T, ok bool, format string, args ...interface{}) bool {
|
||||
if !ok {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
caller := fmt.Sprintf("\nAsserted at:%s:%d", file, line)
|
||||
|
||||
t.Fatalf(format+caller, args...)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
139
system/internal/repository/mocks/credentials.go
Normal file
139
system/internal/repository/mocks/credentials.go
Normal file
@@ -0,0 +1,139 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: system/repository/credentials.go
|
||||
|
||||
// Package repository is a generated GoMock package.
|
||||
package repository
|
||||
|
||||
import (
|
||||
context "context"
|
||||
repository "github.com/crusttech/crust/system/repository"
|
||||
types "github.com/crusttech/crust/system/types"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
factory "github.com/titpetric/factory"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockCredentialsRepository is a mock of CredentialsRepository interface
|
||||
type MockCredentialsRepository struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockCredentialsRepositoryMockRecorder
|
||||
}
|
||||
|
||||
// MockCredentialsRepositoryMockRecorder is the mock recorder for MockCredentialsRepository
|
||||
type MockCredentialsRepositoryMockRecorder struct {
|
||||
mock *MockCredentialsRepository
|
||||
}
|
||||
|
||||
// NewMockCredentialsRepository creates a new mock instance
|
||||
func NewMockCredentialsRepository(ctrl *gomock.Controller) *MockCredentialsRepository {
|
||||
mock := &MockCredentialsRepository{ctrl: ctrl}
|
||||
mock.recorder = &MockCredentialsRepositoryMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockCredentialsRepository) EXPECT() *MockCredentialsRepositoryMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// With mocks base method
|
||||
func (m *MockCredentialsRepository) With(ctx context.Context, db *factory.DB) repository.CredentialsRepository {
|
||||
ret := m.ctrl.Call(m, "With", ctx, db)
|
||||
ret0, _ := ret[0].(repository.CredentialsRepository)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// With indicates an expected call of With
|
||||
func (mr *MockCredentialsRepositoryMockRecorder) With(ctx, db interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "With", reflect.TypeOf((*MockCredentialsRepository)(nil).With), ctx, db)
|
||||
}
|
||||
|
||||
// FindByID mocks base method
|
||||
func (m *MockCredentialsRepository) FindByID(ID uint64) (*types.Credentials, error) {
|
||||
ret := m.ctrl.Call(m, "FindByID", ID)
|
||||
ret0, _ := ret[0].(*types.Credentials)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByID indicates an expected call of FindByID
|
||||
func (mr *MockCredentialsRepositoryMockRecorder) FindByID(ID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByID", reflect.TypeOf((*MockCredentialsRepository)(nil).FindByID), ID)
|
||||
}
|
||||
|
||||
// FindByCredentials mocks base method
|
||||
func (m *MockCredentialsRepository) FindByCredentials(kind types.CredentialsKind, credentials string) (types.CredentialsSet, error) {
|
||||
ret := m.ctrl.Call(m, "FindByCredentials", kind, credentials)
|
||||
ret0, _ := ret[0].(types.CredentialsSet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByCredentials indicates an expected call of FindByCredentials
|
||||
func (mr *MockCredentialsRepositoryMockRecorder) FindByCredentials(kind, credentials interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByCredentials", reflect.TypeOf((*MockCredentialsRepository)(nil).FindByCredentials), kind, credentials)
|
||||
}
|
||||
|
||||
// FindByKind mocks base method
|
||||
func (m *MockCredentialsRepository) FindByKind(ownerID uint64, kind types.CredentialsKind) (types.CredentialsSet, error) {
|
||||
ret := m.ctrl.Call(m, "FindByKind", ownerID, kind)
|
||||
ret0, _ := ret[0].(types.CredentialsSet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByKind indicates an expected call of FindByKind
|
||||
func (mr *MockCredentialsRepositoryMockRecorder) FindByKind(ownerID, kind interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByKind", reflect.TypeOf((*MockCredentialsRepository)(nil).FindByKind), ownerID, kind)
|
||||
}
|
||||
|
||||
// FindByOwnerID mocks base method
|
||||
func (m *MockCredentialsRepository) FindByOwnerID(ownerID uint64) (types.CredentialsSet, error) {
|
||||
ret := m.ctrl.Call(m, "FindByOwnerID", ownerID)
|
||||
ret0, _ := ret[0].(types.CredentialsSet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByOwnerID indicates an expected call of FindByOwnerID
|
||||
func (mr *MockCredentialsRepositoryMockRecorder) FindByOwnerID(ownerID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByOwnerID", reflect.TypeOf((*MockCredentialsRepository)(nil).FindByOwnerID), ownerID)
|
||||
}
|
||||
|
||||
// Find mocks base method
|
||||
func (m *MockCredentialsRepository) Find() (types.CredentialsSet, error) {
|
||||
ret := m.ctrl.Call(m, "Find")
|
||||
ret0, _ := ret[0].(types.CredentialsSet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Find indicates an expected call of Find
|
||||
func (mr *MockCredentialsRepositoryMockRecorder) Find() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Find", reflect.TypeOf((*MockCredentialsRepository)(nil).Find))
|
||||
}
|
||||
|
||||
// Create mocks base method
|
||||
func (m *MockCredentialsRepository) Create(c *types.Credentials) (*types.Credentials, error) {
|
||||
ret := m.ctrl.Call(m, "Create", c)
|
||||
ret0, _ := ret[0].(*types.Credentials)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Create indicates an expected call of Create
|
||||
func (mr *MockCredentialsRepositoryMockRecorder) Create(c interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockCredentialsRepository)(nil).Create), c)
|
||||
}
|
||||
|
||||
// DeleteByID mocks base method
|
||||
func (m *MockCredentialsRepository) DeleteByID(id uint64) error {
|
||||
ret := m.ctrl.Call(m, "DeleteByID", id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteByID indicates an expected call of DeleteByID
|
||||
func (mr *MockCredentialsRepositoryMockRecorder) DeleteByID(id interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteByID", reflect.TypeOf((*MockCredentialsRepository)(nil).DeleteByID), id)
|
||||
}
|
||||
193
system/internal/repository/mocks/user.go
Normal file
193
system/internal/repository/mocks/user.go
Normal file
@@ -0,0 +1,193 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: system/repository/user.go
|
||||
|
||||
// Package repository is a generated GoMock package.
|
||||
package repository
|
||||
|
||||
import (
|
||||
context "context"
|
||||
repository "github.com/crusttech/crust/system/repository"
|
||||
types "github.com/crusttech/crust/system/types"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
factory "github.com/titpetric/factory"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockUserRepository is a mock of UserRepository interface
|
||||
type MockUserRepository struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockUserRepositoryMockRecorder
|
||||
}
|
||||
|
||||
// MockUserRepositoryMockRecorder is the mock recorder for MockUserRepository
|
||||
type MockUserRepositoryMockRecorder struct {
|
||||
mock *MockUserRepository
|
||||
}
|
||||
|
||||
// NewMockUserRepository creates a new mock instance
|
||||
func NewMockUserRepository(ctrl *gomock.Controller) *MockUserRepository {
|
||||
mock := &MockUserRepository{ctrl: ctrl}
|
||||
mock.recorder = &MockUserRepositoryMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockUserRepository) EXPECT() *MockUserRepositoryMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// With mocks base method
|
||||
func (m *MockUserRepository) With(ctx context.Context, db *factory.DB) repository.UserRepository {
|
||||
ret := m.ctrl.Call(m, "With", ctx, db)
|
||||
ret0, _ := ret[0].(repository.UserRepository)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// With indicates an expected call of With
|
||||
func (mr *MockUserRepositoryMockRecorder) With(ctx, db interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "With", reflect.TypeOf((*MockUserRepository)(nil).With), ctx, db)
|
||||
}
|
||||
|
||||
// FindByEmail mocks base method
|
||||
func (m *MockUserRepository) FindByEmail(email string) (*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "FindByEmail", email)
|
||||
ret0, _ := ret[0].(*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByEmail indicates an expected call of FindByEmail
|
||||
func (mr *MockUserRepositoryMockRecorder) FindByEmail(email interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByEmail", reflect.TypeOf((*MockUserRepository)(nil).FindByEmail), email)
|
||||
}
|
||||
|
||||
// FindByUsername mocks base method
|
||||
func (m *MockUserRepository) FindByUsername(username string) (*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "FindByUsername", username)
|
||||
ret0, _ := ret[0].(*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByUsername indicates an expected call of FindByUsername
|
||||
func (mr *MockUserRepositoryMockRecorder) FindByUsername(username interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByUsername", reflect.TypeOf((*MockUserRepository)(nil).FindByUsername), username)
|
||||
}
|
||||
|
||||
// FindByID mocks base method
|
||||
func (m *MockUserRepository) FindByID(id uint64) (*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "FindByID", id)
|
||||
ret0, _ := ret[0].(*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByID indicates an expected call of FindByID
|
||||
func (mr *MockUserRepositoryMockRecorder) FindByID(id interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByID", reflect.TypeOf((*MockUserRepository)(nil).FindByID), id)
|
||||
}
|
||||
|
||||
// FindByIDs mocks base method
|
||||
func (m *MockUserRepository) FindByIDs(id ...uint64) (types.UserSet, error) {
|
||||
varargs := []interface{}{}
|
||||
for _, a := range id {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "FindByIDs", varargs...)
|
||||
ret0, _ := ret[0].(types.UserSet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByIDs indicates an expected call of FindByIDs
|
||||
func (mr *MockUserRepositoryMockRecorder) FindByIDs(id ...interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByIDs", reflect.TypeOf((*MockUserRepository)(nil).FindByIDs), id...)
|
||||
}
|
||||
|
||||
// FindBySatosaID mocks base method
|
||||
func (m *MockUserRepository) FindBySatosaID(id string) (*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "FindBySatosaID", id)
|
||||
ret0, _ := ret[0].(*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindBySatosaID indicates an expected call of FindBySatosaID
|
||||
func (mr *MockUserRepositoryMockRecorder) FindBySatosaID(id interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindBySatosaID", reflect.TypeOf((*MockUserRepository)(nil).FindBySatosaID), id)
|
||||
}
|
||||
|
||||
// Find mocks base method
|
||||
func (m *MockUserRepository) Find(filter *types.UserFilter) ([]*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "Find", filter)
|
||||
ret0, _ := ret[0].([]*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Find indicates an expected call of Find
|
||||
func (mr *MockUserRepositoryMockRecorder) Find(filter interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Find", reflect.TypeOf((*MockUserRepository)(nil).Find), filter)
|
||||
}
|
||||
|
||||
// Create mocks base method
|
||||
func (m *MockUserRepository) Create(mod *types.User) (*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "Create", mod)
|
||||
ret0, _ := ret[0].(*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Create indicates an expected call of Create
|
||||
func (mr *MockUserRepositoryMockRecorder) Create(mod interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockUserRepository)(nil).Create), mod)
|
||||
}
|
||||
|
||||
// Update mocks base method
|
||||
func (m *MockUserRepository) Update(mod *types.User) (*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "Update", mod)
|
||||
ret0, _ := ret[0].(*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Update indicates an expected call of Update
|
||||
func (mr *MockUserRepositoryMockRecorder) Update(mod interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockUserRepository)(nil).Update), mod)
|
||||
}
|
||||
|
||||
// SuspendByID mocks base method
|
||||
func (m *MockUserRepository) SuspendByID(id uint64) error {
|
||||
ret := m.ctrl.Call(m, "SuspendByID", id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SuspendByID indicates an expected call of SuspendByID
|
||||
func (mr *MockUserRepositoryMockRecorder) SuspendByID(id interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SuspendByID", reflect.TypeOf((*MockUserRepository)(nil).SuspendByID), id)
|
||||
}
|
||||
|
||||
// UnsuspendByID mocks base method
|
||||
func (m *MockUserRepository) UnsuspendByID(id uint64) error {
|
||||
ret := m.ctrl.Call(m, "UnsuspendByID", id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UnsuspendByID indicates an expected call of UnsuspendByID
|
||||
func (mr *MockUserRepositoryMockRecorder) UnsuspendByID(id interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnsuspendByID", reflect.TypeOf((*MockUserRepository)(nil).UnsuspendByID), id)
|
||||
}
|
||||
|
||||
// DeleteByID mocks base method
|
||||
func (m *MockUserRepository) DeleteByID(id uint64) error {
|
||||
ret := m.ctrl.Call(m, "DeleteByID", id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteByID indicates an expected call of DeleteByID
|
||||
func (mr *MockUserRepositoryMockRecorder) DeleteByID(id interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteByID", reflect.TypeOf((*MockUserRepository)(nil).DeleteByID), id)
|
||||
}
|
||||
97
system/internal/repository/organisation.go
Normal file
97
system/internal/repository/organisation.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
type (
|
||||
OrganisationRepository interface {
|
||||
With(ctx context.Context, db *factory.DB) OrganisationRepository
|
||||
|
||||
FindOrganisationByID(id uint64) (*types.Organisation, error)
|
||||
FindOrganisations(filter *types.OrganisationFilter) ([]*types.Organisation, error)
|
||||
CreateOrganisation(mod *types.Organisation) (*types.Organisation, error)
|
||||
UpdateOrganisation(mod *types.Organisation) (*types.Organisation, error)
|
||||
ArchiveOrganisationByID(id uint64) error
|
||||
UnarchiveOrganisationByID(id uint64) error
|
||||
DeleteOrganisationByID(id uint64) error
|
||||
}
|
||||
|
||||
organisation struct {
|
||||
*repository
|
||||
|
||||
// sql table reference
|
||||
organisations string
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
sqlOrganisationScope = "deleted_at IS NULL AND archived_at IS NULL"
|
||||
|
||||
ErrOrganisationNotFound = repositoryError("OrganisationNotFound")
|
||||
)
|
||||
|
||||
func Organisation(ctx context.Context, db *factory.DB) OrganisationRepository {
|
||||
return (&organisation{}).With(ctx, db)
|
||||
}
|
||||
|
||||
func (r *organisation) With(ctx context.Context, db *factory.DB) OrganisationRepository {
|
||||
return &organisation{
|
||||
repository: r.repository.With(ctx, db),
|
||||
organisations: "sys_organisation",
|
||||
}
|
||||
}
|
||||
|
||||
func (r *organisation) FindOrganisationByID(id uint64) (*types.Organisation, error) {
|
||||
sql := "SELECT * FROM " + r.organisations + " WHERE id = ? AND " + sqlOrganisationScope
|
||||
mod := &types.Organisation{}
|
||||
|
||||
return mod, isFound(r.db().Get(mod, sql, id), mod.ID > 0, ErrOrganisationNotFound)
|
||||
}
|
||||
|
||||
func (r *organisation) FindOrganisations(filter *types.OrganisationFilter) ([]*types.Organisation, error) {
|
||||
rval := make([]*types.Organisation, 0)
|
||||
params := make([]interface{}, 0)
|
||||
sql := "SELECT * FROM " + r.organisations + " WHERE " + sqlOrganisationScope
|
||||
|
||||
if filter != nil {
|
||||
if filter.Query != "" {
|
||||
sql += " AND name LIKE ?"
|
||||
params = append(params, filter.Query+"%")
|
||||
}
|
||||
}
|
||||
|
||||
sql += " ORDER BY name ASC"
|
||||
|
||||
return rval, r.db().Select(&rval, sql, params...)
|
||||
}
|
||||
|
||||
func (r *organisation) CreateOrganisation(mod *types.Organisation) (*types.Organisation, error) {
|
||||
mod.ID = factory.Sonyflake.NextID()
|
||||
mod.CreatedAt = time.Now()
|
||||
|
||||
return mod, r.db().Insert(r.organisations, mod)
|
||||
}
|
||||
|
||||
func (r *organisation) UpdateOrganisation(mod *types.Organisation) (*types.Organisation, error) {
|
||||
mod.UpdatedAt = timeNowPtr()
|
||||
|
||||
return mod, r.db().Replace(r.organisations, mod)
|
||||
}
|
||||
|
||||
func (r *organisation) ArchiveOrganisationByID(id uint64) error {
|
||||
return r.updateColumnByID(r.organisations, "archived_at", time.Now(), id)
|
||||
}
|
||||
|
||||
func (r *organisation) UnarchiveOrganisationByID(id uint64) error {
|
||||
return r.updateColumnByID(r.organisations, "archived_at", nil, id)
|
||||
}
|
||||
|
||||
func (r *organisation) DeleteOrganisationByID(id uint64) error {
|
||||
return r.updateColumnByID(r.organisations, "deleted_at", time.Now(), id)
|
||||
}
|
||||
72
system/internal/repository/organisation_test.go
Normal file
72
system/internal/repository/organisation_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/system/types"
|
||||
|
||||
. "github.com/crusttech/crust/internal/test"
|
||||
)
|
||||
|
||||
func TestOrganisation(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode.")
|
||||
return
|
||||
}
|
||||
|
||||
db := factory.Database.MustGet()
|
||||
|
||||
// Run tests in transaction to maintain DB state.
|
||||
Error(t, db.Transaction(func() error {
|
||||
rpo := Organisation(context.Background(), db)
|
||||
org := &types.Organisation{
|
||||
Name: "Test organisation v1",
|
||||
}
|
||||
|
||||
{
|
||||
oa, err := rpo.CreateOrganisation(org)
|
||||
assert(t, err == nil, "CreateOrganisation error: %+v", err)
|
||||
assert(t, oa.Name == org.Name, "Changes were not stored")
|
||||
}
|
||||
|
||||
{
|
||||
org.Name = "Test organisation v2"
|
||||
|
||||
oa, err := rpo.UpdateOrganisation(org)
|
||||
assert(t, err == nil, "UpdateOrganisation error: %+v", err)
|
||||
assert(t, oa.Name == org.Name, "Changes were not stored")
|
||||
}
|
||||
|
||||
{
|
||||
oa, err := rpo.FindOrganisationByID(org.ID)
|
||||
assert(t, err == nil, "FindOrganisationByID error: %+v", err)
|
||||
assert(t, oa.Name == org.Name, "Changes were not stored")
|
||||
}
|
||||
|
||||
{
|
||||
oa, err := rpo.FindOrganisations(&types.OrganisationFilter{Query: org.Name})
|
||||
assert(t, err == nil, "FindOrganisations error: %+v", err)
|
||||
assert(t, len(oa) != 0, "No results found")
|
||||
}
|
||||
|
||||
{
|
||||
err := rpo.ArchiveOrganisationByID(org.ID)
|
||||
assert(t, err == nil, "ArchiveOrganisationByID error: %+v", err)
|
||||
}
|
||||
|
||||
{
|
||||
err := rpo.UnarchiveOrganisationByID(org.ID)
|
||||
assert(t, err == nil, "UnarchiveOrganisationByID error: %+v", err)
|
||||
}
|
||||
|
||||
{
|
||||
err := rpo.DeleteOrganisationByID(org.ID)
|
||||
assert(t, err == nil, "DeleteOrganisationByID error: %+v", err)
|
||||
}
|
||||
return errors.New("Rollback")
|
||||
}), "expected rollback error")
|
||||
}
|
||||
39
system/internal/repository/repository.go
Normal file
39
system/internal/repository/repository.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/titpetric/factory"
|
||||
)
|
||||
|
||||
type (
|
||||
repository struct {
|
||||
ctx context.Context
|
||||
dbh *factory.DB
|
||||
}
|
||||
)
|
||||
|
||||
// DB produces a contextual DB handle
|
||||
func DB(ctx context.Context) *factory.DB {
|
||||
return factory.Database.MustGet("system").With(ctx)
|
||||
}
|
||||
|
||||
// With updates repository and database contexts
|
||||
func (r *repository) With(ctx context.Context, db *factory.DB) *repository {
|
||||
return &repository{
|
||||
ctx: ctx,
|
||||
dbh: db,
|
||||
}
|
||||
}
|
||||
|
||||
// Context returns current active repository context
|
||||
func (r *repository) Context() context.Context {
|
||||
return r.ctx
|
||||
}
|
||||
|
||||
// db returns context-aware db handle
|
||||
func (r *repository) db() *factory.DB {
|
||||
if r.dbh != nil {
|
||||
return r.dbh
|
||||
}
|
||||
return DB(r.ctx)
|
||||
}
|
||||
24
system/internal/repository/repository_test.go
Normal file
24
system/internal/repository/repository_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
func tx(t *testing.T, f func() error) {
|
||||
db := DB(context.Background())
|
||||
|
||||
if err := db.Begin(); err != nil {
|
||||
t.Errorf("Could not begin transaction: %v", err)
|
||||
|
||||
}
|
||||
|
||||
if err := f(); err != nil {
|
||||
t.Errorf("Test transaction resulted in an error: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Rollback(); err != nil {
|
||||
t.Errorf("Could not rollback transaction: %v", err)
|
||||
}
|
||||
}
|
||||
164
system/internal/repository/role.go
Normal file
164
system/internal/repository/role.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
type (
|
||||
RoleRepository interface {
|
||||
With(ctx context.Context, db *factory.DB) RoleRepository
|
||||
|
||||
FindByID(id uint64) (*types.Role, error)
|
||||
FindByMemberID(userID uint64) ([]*types.Role, error)
|
||||
Find(filter *types.RoleFilter) ([]*types.Role, error)
|
||||
|
||||
Create(mod *types.Role) (*types.Role, error)
|
||||
Update(mod *types.Role) (*types.Role, error)
|
||||
|
||||
ArchiveByID(id uint64) error
|
||||
UnarchiveByID(id uint64) error
|
||||
DeleteByID(id uint64) error
|
||||
|
||||
MergeByID(id, targetRoleID uint64) error
|
||||
MoveByID(id, targetOrganisationID uint64) error
|
||||
|
||||
MemberFindByRoleID(roleID uint64) ([]*types.RoleMember, error)
|
||||
MemberAddByID(roleID, userID uint64) error
|
||||
MemberRemoveByID(roleID, userID uint64) error
|
||||
}
|
||||
|
||||
role struct {
|
||||
*repository
|
||||
|
||||
// sql table reference
|
||||
roles string
|
||||
members string
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
sqlRoleScope = "deleted_at IS NULL AND archived_at IS NULL"
|
||||
|
||||
ErrRoleNotFound = repositoryError("RoleNotFound")
|
||||
)
|
||||
|
||||
func Role(ctx context.Context, db *factory.DB) RoleRepository {
|
||||
return (&role{}).With(ctx, db)
|
||||
}
|
||||
|
||||
func (r *role) With(ctx context.Context, db *factory.DB) RoleRepository {
|
||||
return &role{
|
||||
repository: r.repository.With(ctx, db),
|
||||
roles: "sys_role",
|
||||
members: "sys_role_member",
|
||||
}
|
||||
}
|
||||
|
||||
func (r *role) FindByID(id uint64) (*types.Role, error) {
|
||||
sql := "SELECT * FROM " + r.roles + " WHERE id = ? AND " + sqlRoleScope
|
||||
mod := &types.Role{}
|
||||
|
||||
return mod, isFound(r.db().Get(mod, sql, id), mod.ID > 0, ErrRoleNotFound)
|
||||
}
|
||||
|
||||
func (r *role) FindByMemberID(userID uint64) ([]*types.Role, error) {
|
||||
ids := make([]uint64, 0)
|
||||
params := make([]interface{}, 0)
|
||||
|
||||
sql := "SELECT DISTINCT rel_role FROM " + r.members + " "
|
||||
sql += "WHERE rel_user = ?"
|
||||
params = append(params, userID)
|
||||
|
||||
if err := r.db().Select(&ids, sql, params...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rval := make([]*types.Role, 0)
|
||||
for _, id := range ids {
|
||||
mod, err := r.FindByID(id)
|
||||
if err != nil && err != ErrRoleNotFound {
|
||||
return nil, err
|
||||
}
|
||||
rval = append(rval, mod)
|
||||
}
|
||||
|
||||
return rval, nil
|
||||
}
|
||||
|
||||
func (r *role) Find(filter *types.RoleFilter) ([]*types.Role, error) {
|
||||
rval := make([]*types.Role, 0)
|
||||
params := make([]interface{}, 0)
|
||||
|
||||
sql := "SELECT * FROM " + r.roles + " WHERE " + sqlRoleScope
|
||||
|
||||
if filter != nil {
|
||||
if filter.Query != "" {
|
||||
sql += " AND name LIKE ?"
|
||||
params = append(params, filter.Query+"%")
|
||||
}
|
||||
}
|
||||
|
||||
sql += " ORDER BY name ASC"
|
||||
|
||||
return rval, r.db().Select(&rval, sql, params...)
|
||||
}
|
||||
|
||||
func (r *role) Create(mod *types.Role) (*types.Role, error) {
|
||||
mod.ID = factory.Sonyflake.NextID()
|
||||
mod.CreatedAt = time.Now()
|
||||
|
||||
return mod, r.db().Insert(r.roles, mod)
|
||||
}
|
||||
|
||||
func (r *role) Update(mod *types.Role) (*types.Role, error) {
|
||||
mod.UpdatedAt = timeNowPtr()
|
||||
|
||||
return mod, r.db().Replace(r.roles, mod)
|
||||
}
|
||||
|
||||
func (r *role) ArchiveByID(id uint64) error {
|
||||
return r.updateColumnByID(r.roles, "archived_at", time.Now(), id)
|
||||
}
|
||||
|
||||
func (r *role) UnarchiveByID(id uint64) error {
|
||||
return r.updateColumnByID(r.roles, "archived_at", nil, id)
|
||||
}
|
||||
|
||||
func (r *role) DeleteByID(id uint64) error {
|
||||
return r.updateColumnByID(r.roles, "deleted_at", time.Now(), id)
|
||||
}
|
||||
|
||||
func (r *role) MergeByID(id, targetRoleID uint64) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (r *role) MoveByID(id, targetOrganisationID uint64) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (r *role) MemberFindByRoleID(roleID uint64) (mm []*types.RoleMember, err error) {
|
||||
rval := make([]*types.RoleMember, 0)
|
||||
sql := "SELECT * FROM " + r.members + " WHERE rel_role = ?"
|
||||
return rval, r.db().Select(&rval, sql, roleID)
|
||||
}
|
||||
|
||||
func (r *role) MemberAddByID(roleID, userID uint64) error {
|
||||
mod := &types.RoleMember{
|
||||
RoleID: roleID,
|
||||
UserID: userID,
|
||||
}
|
||||
return r.db().Replace(r.members, mod)
|
||||
}
|
||||
|
||||
func (r *role) MemberRemoveByID(roleID, userID uint64) error {
|
||||
mod := &types.RoleMember{
|
||||
RoleID: roleID,
|
||||
UserID: userID,
|
||||
}
|
||||
return r.db().Delete(r.members, mod, "rel_role", "rel_user")
|
||||
}
|
||||
110
system/internal/repository/role_test.go
Normal file
110
system/internal/repository/role_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/system/types"
|
||||
|
||||
. "github.com/crusttech/crust/internal/test"
|
||||
)
|
||||
|
||||
func TestRole(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode.")
|
||||
return
|
||||
}
|
||||
|
||||
db := factory.Database.MustGet()
|
||||
|
||||
Error(t, db.Transaction(func() error {
|
||||
userRepo := User(context.Background(), db)
|
||||
user := &types.User{
|
||||
Name: "John Role Doe",
|
||||
Username: "johndoe",
|
||||
}
|
||||
user.GeneratePassword("johndoe")
|
||||
|
||||
{
|
||||
u1, err := userRepo.Create(user)
|
||||
assert(t, err == nil, "Owner.Create error: %+v", err)
|
||||
assert(t, user.ID == u1.ID, "Changes were not stored")
|
||||
}
|
||||
|
||||
roleRepo := Role(context.Background(), db)
|
||||
role := &types.Role{
|
||||
Name: "Test role v1",
|
||||
}
|
||||
|
||||
{
|
||||
t1, err := roleRepo.Create(role)
|
||||
assert(t, err == nil, "Role.Create error: %+v", err)
|
||||
assert(t, role.Name == t1.Name, "Changes were not stored")
|
||||
}
|
||||
|
||||
{
|
||||
role.Name = "Test role v2"
|
||||
t1, err := roleRepo.Update(role)
|
||||
assert(t, err == nil, "Role.Update error: %+v", err)
|
||||
assert(t, role.Name == t1.Name, "Changes were not stored")
|
||||
}
|
||||
|
||||
{
|
||||
t1, err := roleRepo.FindByID(role.ID)
|
||||
assert(t, err == nil, "Role.FindByID error: %+v", err)
|
||||
assert(t, role.Name == t1.Name, "Changes were not stored")
|
||||
}
|
||||
|
||||
{
|
||||
aa, err := roleRepo.Find(&types.RoleFilter{Query: role.Name})
|
||||
assert(t, err == nil, "Role.Find error: %+v", err)
|
||||
assert(t, len(aa) > 0, "No results found")
|
||||
}
|
||||
|
||||
{
|
||||
err := roleRepo.ArchiveByID(role.ID)
|
||||
assert(t, err == nil, "Role.ArchiveByID error: %+v", err)
|
||||
}
|
||||
|
||||
{
|
||||
err := roleRepo.UnarchiveByID(role.ID)
|
||||
assert(t, err == nil, "Role.UnarchiveByID error: %+v", err)
|
||||
}
|
||||
|
||||
{
|
||||
err := roleRepo.MemberAddByID(role.ID, user.ID)
|
||||
assert(t, err == nil, "Role.MemberAddByID error: %+v", err)
|
||||
}
|
||||
|
||||
{
|
||||
roles, err := roleRepo.FindByMemberID(user.ID)
|
||||
assert(t, err == nil, "Role.FindByMemberID error: %+v", err)
|
||||
assert(t, len(roles) > 0, "No results found")
|
||||
}
|
||||
|
||||
{
|
||||
roles, err := roleRepo.FindByMemberID(0)
|
||||
assert(t, err == nil, "Role.FindByMemberID error: %+v", err)
|
||||
assert(t, len(roles) == 0, "Results found")
|
||||
}
|
||||
|
||||
{
|
||||
err := roleRepo.MemberRemoveByID(role.ID, user.ID)
|
||||
assert(t, err == nil, "Role.MemberRemoveByID error: %+v", err)
|
||||
}
|
||||
|
||||
{
|
||||
err := roleRepo.DeleteByID(role.ID)
|
||||
assert(t, err == nil, "Role.DeleteByID error: %+v", err)
|
||||
}
|
||||
|
||||
{
|
||||
err := userRepo.DeleteByID(user.ID)
|
||||
assert(t, err == nil, "Owner.DeleteByID error: %+v", err)
|
||||
}
|
||||
return errors.New("Rollback")
|
||||
}), "expected rollback error")
|
||||
}
|
||||
63
system/internal/repository/settings.go
Normal file
63
system/internal/repository/settings.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/titpetric/factory"
|
||||
)
|
||||
|
||||
type (
|
||||
settings struct {
|
||||
*repository
|
||||
|
||||
// sql table reference
|
||||
settings string
|
||||
}
|
||||
|
||||
Settings interface {
|
||||
With(ctx context.Context, db *factory.DB) Settings
|
||||
|
||||
Get(name string, value interface{}) (bool, error)
|
||||
Set(name string, value interface{}) error
|
||||
}
|
||||
)
|
||||
|
||||
func NewSettings(ctx context.Context, db *factory.DB) Settings {
|
||||
return (&settings{}).With(ctx, db)
|
||||
}
|
||||
|
||||
func (r *settings) With(ctx context.Context, db *factory.DB) Settings {
|
||||
return &settings{
|
||||
repository: r.repository.With(ctx, db),
|
||||
settings: "settings",
|
||||
}
|
||||
}
|
||||
|
||||
func (r *settings) Set(name string, value interface{}) error {
|
||||
if jsonValue, err := json.Marshal(value); err != nil {
|
||||
return errors.Wrap(err, "Error marshaling settings value")
|
||||
} else {
|
||||
return r.db().Replace(r.settings, struct {
|
||||
Key string `db:"name"`
|
||||
Val json.RawMessage `db:"value"`
|
||||
}{name, jsonValue})
|
||||
}
|
||||
}
|
||||
|
||||
func (r *settings) Get(name string, value interface{}) (bool, error) {
|
||||
sql := "SELECT value FROM " + r.settings + " WHERE name = ?"
|
||||
|
||||
var stored json.RawMessage
|
||||
|
||||
if err := r.db().Get(&stored, sql, name); err != nil {
|
||||
return false, errors.Wrap(err, "Error reading settings from the database")
|
||||
} else if stored == nil {
|
||||
return false, nil
|
||||
} else if err := json.Unmarshal(stored, value); err != nil {
|
||||
return false, errors.Wrap(err, "Error unmarshaling settings value")
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
206
system/internal/repository/user.go
Normal file
206
system/internal/repository/user.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
type (
|
||||
UserRepository interface {
|
||||
With(ctx context.Context, db *factory.DB) UserRepository
|
||||
|
||||
FindByEmail(email string) (*types.User, error)
|
||||
FindByUsername(username string) (*types.User, error)
|
||||
FindByID(id uint64) (*types.User, error)
|
||||
FindByIDs(id ...uint64) (types.UserSet, error)
|
||||
FindBySatosaID(id string) (*types.User, error)
|
||||
Find(filter *types.UserFilter) ([]*types.User, error)
|
||||
|
||||
Create(mod *types.User) (*types.User, error)
|
||||
Update(mod *types.User) (*types.User, error)
|
||||
|
||||
SuspendByID(id uint64) error
|
||||
UnsuspendByID(id uint64) error
|
||||
DeleteByID(id uint64) error
|
||||
}
|
||||
|
||||
user struct {
|
||||
*repository
|
||||
|
||||
// sql table reference
|
||||
users string
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
sqlUserColumns = "id, email, username, password, name, handle, " +
|
||||
"meta, satosa_id, rel_organisation, " +
|
||||
"created_at, updated_at, suspended_at, deleted_at"
|
||||
sqlUserScope = "suspended_at IS NULL AND deleted_at IS NULL"
|
||||
sqlUserSelect = "SELECT " + sqlUserColumns + " FROM %s WHERE " + sqlUserScope
|
||||
|
||||
ErrUserNotFound = repositoryError("UserNotFound")
|
||||
)
|
||||
|
||||
func User(ctx context.Context, db *factory.DB) UserRepository {
|
||||
return (&user{}).With(ctx, db)
|
||||
}
|
||||
|
||||
func (r *user) With(ctx context.Context, db *factory.DB) UserRepository {
|
||||
return &user{
|
||||
repository: r.repository.With(ctx, db),
|
||||
users: "sys_user",
|
||||
}
|
||||
}
|
||||
|
||||
func (r *user) FindByUsername(username string) (*types.User, error) {
|
||||
sql := fmt.Sprintf(sqlUserSelect, r.users) + " AND username = ?"
|
||||
mod := &types.User{}
|
||||
|
||||
return mod, isFound(r.db().Get(mod, sql, username), mod.ID > 0, ErrUserNotFound)
|
||||
}
|
||||
|
||||
func (r *user) FindBySatosaID(satosaID string) (*types.User, error) {
|
||||
sql := fmt.Sprintf(sqlUserSelect, r.users) + " AND satosa_id = ?"
|
||||
mod := &types.User{}
|
||||
|
||||
return mod, isFound(r.db().Get(mod, sql, satosaID), mod.ID > 0, ErrUserNotFound)
|
||||
}
|
||||
|
||||
func (r *user) FindByEmail(email string) (*types.User, error) {
|
||||
sql := fmt.Sprintf(sqlUserSelect, r.users) + " AND email = ?"
|
||||
mod := &types.User{}
|
||||
|
||||
return mod, isFound(r.db().Get(mod, sql, email), mod.ID > 0, ErrUserNotFound)
|
||||
}
|
||||
|
||||
func (r *user) FindByID(id uint64) (*types.User, error) {
|
||||
sql := fmt.Sprintf(sqlUserSelect, r.users) + " AND id = ?"
|
||||
mod := &types.User{}
|
||||
if err := isFound(r.db().Get(mod, sql, id), mod.ID > 0, ErrUserNotFound); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err := r.prepare(mod, "roles")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mod, nil
|
||||
}
|
||||
|
||||
func (r *user) FindByIDs(IDs ...uint64) (uu types.UserSet, err error) {
|
||||
if len(IDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sql := fmt.Sprintf(sqlUserSelect, r.users) + " AND id IN (?)"
|
||||
|
||||
if sql, args, err := sqlx.In(sql, IDs); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return uu, r.db().Select(&uu, sql, args...)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (r *user) Find(filter *types.UserFilter) ([]*types.User, error) {
|
||||
if filter == nil {
|
||||
filter = &types.UserFilter{}
|
||||
}
|
||||
|
||||
rval := make([]*types.User, 0)
|
||||
params := make([]interface{}, 0)
|
||||
sql := fmt.Sprintf(sqlUserSelect, r.users)
|
||||
|
||||
if filter.Query != "" {
|
||||
sql += " AND (username LIKE ?"
|
||||
params = append(params, filter.Query+"%")
|
||||
sql += " OR email LIKE ?"
|
||||
params = append(params, filter.Query+"%")
|
||||
sql += " OR name LIKE ?)"
|
||||
params = append(params, filter.Query+"%")
|
||||
}
|
||||
|
||||
if filter.Email != "" {
|
||||
sql += " AND (email = ?)"
|
||||
params = append(params, filter.Email)
|
||||
}
|
||||
|
||||
if filter.Username != "" {
|
||||
sql += " AND (username = ?)"
|
||||
params = append(params, filter.Username)
|
||||
}
|
||||
|
||||
switch filter.OrderBy {
|
||||
case "updated_at", "createdAt":
|
||||
sql += " ORDER BY updated_at DESC"
|
||||
default:
|
||||
sql += " ORDER BY username ASC"
|
||||
}
|
||||
|
||||
if err := r.db().Select(&rval, sql, params...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := r.prepareAll(rval, "roles"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return rval, nil
|
||||
}
|
||||
|
||||
func (r *user) Create(mod *types.User) (*types.User, error) {
|
||||
mod.ID = factory.Sonyflake.NextID()
|
||||
mod.CreatedAt = time.Now()
|
||||
return mod, r.db().Insert(r.users, mod)
|
||||
}
|
||||
|
||||
func (r *user) Update(mod *types.User) (*types.User, error) {
|
||||
mod.UpdatedAt = timeNowPtr()
|
||||
return mod, r.db().Replace(r.users, mod)
|
||||
}
|
||||
|
||||
func (r *user) SuspendByID(id uint64) error {
|
||||
return r.updateColumnByID(r.users, "suspend_at", time.Now(), id)
|
||||
}
|
||||
|
||||
func (r *user) UnsuspendByID(id uint64) error {
|
||||
return r.updateColumnByID(r.users, "suspend_at", nil, id)
|
||||
}
|
||||
|
||||
func (r *user) DeleteByID(id uint64) error {
|
||||
return r.updateColumnByID(r.users, "deleted_at", time.Now(), id)
|
||||
}
|
||||
|
||||
func (r *user) prepareAll(users []*types.User, fields ...string) error {
|
||||
for _, user := range users {
|
||||
if err := r.prepare(user, fields...); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *user) prepare(user *types.User, fields ...string) (err error) {
|
||||
api := Role(r.Context(), r.db())
|
||||
for _, field := range fields {
|
||||
switch field {
|
||||
case "roles":
|
||||
if user.ID > 0 {
|
||||
roles, err := api.FindByMemberID(user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user.Roles = roles
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
67
system/internal/repository/user_test.go
Normal file
67
system/internal/repository/user_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/system/types"
|
||||
|
||||
. "github.com/crusttech/crust/internal/test"
|
||||
)
|
||||
|
||||
func TestUser(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode.")
|
||||
return
|
||||
}
|
||||
|
||||
db := factory.Database.MustGet()
|
||||
|
||||
// Run tests in transaction to maintain DB state.
|
||||
Error(t, db.Transaction(func() error {
|
||||
userRepo := User(context.Background(), db)
|
||||
user := &types.User{
|
||||
Name: "John User Doe",
|
||||
Username: "johndoe",
|
||||
SatosaID: "1234",
|
||||
}
|
||||
user.GeneratePassword("johndoe")
|
||||
|
||||
{
|
||||
uu, err := userRepo.Create(user)
|
||||
assert(t, err == nil, "Owner.Create error: %+v", err)
|
||||
assert(t, user.ID == uu.ID, "Changes were not stored")
|
||||
}
|
||||
|
||||
roleRepo := Role(context.Background(), db)
|
||||
role := &types.Role{
|
||||
Name: "Test role v1",
|
||||
}
|
||||
|
||||
{
|
||||
t1, err := roleRepo.Create(role)
|
||||
assert(t, err == nil, "Role.Create error: %+v", err)
|
||||
assert(t, role.Name == t1.Name, "Changes were not stored")
|
||||
|
||||
err = roleRepo.MemberAddByID(t1.ID, user.ID)
|
||||
assert(t, err == nil, "Role.MemberAddByID error: %+v", err)
|
||||
}
|
||||
|
||||
{
|
||||
uu, err := userRepo.FindByID(user.ID)
|
||||
assert(t, err == nil, "Owner.FindByID error: %+v", err)
|
||||
assert(t, len(uu.Roles) == 1, "Expected 1 role, got %d", len(uu.Roles))
|
||||
}
|
||||
|
||||
{
|
||||
users, err := userRepo.Find(&types.UserFilter{Query: "John User Doe"})
|
||||
assert(t, err == nil, "Owner.Find error: %+v", err)
|
||||
assert(t, len(users) == 1, "Owner.Find: expected 1 user, got %d", len(users))
|
||||
assert(t, len(users[0].Roles) == 1, "Owner.Find: expected 1 role, got %d", len(users[0].Roles))
|
||||
}
|
||||
return errors.New("Rollback")
|
||||
}), "expected rollback error")
|
||||
}
|
||||
33
system/internal/repository/util.go
Normal file
33
system/internal/repository/util.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (r repository) updateColumnByID(tableName, columnName string, value interface{}, id uint64) (err error) {
|
||||
return exec(r.db().Exec(
|
||||
fmt.Sprintf("UPDATE %s SET %s = ? WHERE id = ?", tableName, columnName),
|
||||
value,
|
||||
id))
|
||||
}
|
||||
|
||||
func exec(_ interface{}, err error) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Returns err if set otherwise it returns nerr if not valid
|
||||
func isFound(err error, valid bool, nerr error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !valid {
|
||||
return nerr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func timeNowPtr() *time.Time {
|
||||
n := time.Now()
|
||||
return &n
|
||||
}
|
||||
122
system/internal/service/application.go
Normal file
122
system/internal/service/application.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/system/internal/repository"
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
type (
|
||||
application struct {
|
||||
db *factory.DB
|
||||
ctx context.Context
|
||||
|
||||
prm PermissionsService
|
||||
|
||||
application repository.ApplicationRepository
|
||||
}
|
||||
|
||||
ApplicationService interface {
|
||||
With(ctx context.Context) ApplicationService
|
||||
|
||||
FindByID(applicationID uint64) (*types.Application, error)
|
||||
Find() (types.ApplicationSet, error)
|
||||
|
||||
Create(application *types.Application) (*types.Application, error)
|
||||
Update(application *types.Application) (*types.Application, error)
|
||||
DeleteByID(id uint64) error
|
||||
}
|
||||
)
|
||||
|
||||
func Application() ApplicationService {
|
||||
return (&application{
|
||||
prm: DefaultPermissions,
|
||||
}).With(context.Background())
|
||||
}
|
||||
|
||||
func (svc *application) With(ctx context.Context) ApplicationService {
|
||||
db := repository.DB(ctx)
|
||||
return &application{
|
||||
db: db,
|
||||
ctx: ctx,
|
||||
prm: svc.prm.With(ctx),
|
||||
application: repository.Application(ctx, db),
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *application) FindByID(id uint64) (*types.Application, error) {
|
||||
app, err := svc.application.FindByID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !svc.prm.CanReadApplication(app) {
|
||||
return nil, errors.New("Not allowed to access application")
|
||||
}
|
||||
|
||||
return app, nil
|
||||
}
|
||||
|
||||
func (svc *application) Find() (types.ApplicationSet, error) {
|
||||
apps, err := svc.application.Find()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret := []*types.Application{}
|
||||
for _, app := range apps {
|
||||
if svc.prm.CanReadApplication(app) {
|
||||
ret = append(ret, app)
|
||||
}
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (svc *application) Create(mod *types.Application) (*types.Application, error) {
|
||||
if !svc.prm.CanCreateApplication() {
|
||||
return nil, errors.New("Not allowed to create application")
|
||||
}
|
||||
return svc.application.Create(mod)
|
||||
}
|
||||
|
||||
func (svc *application) Update(mod *types.Application) (t *types.Application, err error) {
|
||||
if !svc.prm.CanUpdateApplication(mod) {
|
||||
return nil, errors.New("Not allowed to update application")
|
||||
}
|
||||
|
||||
// @todo: make sure archived & deleted entries can not be edited
|
||||
|
||||
return t, svc.db.Transaction(func() (err error) {
|
||||
if t, err = svc.application.FindByID(mod.ID); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Assign changed values
|
||||
t.Name = mod.Name
|
||||
t.Enabled = mod.Enabled
|
||||
t.Unify = mod.Unify
|
||||
|
||||
if t, err = svc.application.Update(t); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (svc *application) DeleteByID(id uint64) error {
|
||||
// @todo: make history unavailable
|
||||
// @todo: notify users that application has been removed (remove from web UI)
|
||||
|
||||
app := &types.Application{ID: id}
|
||||
if !svc.prm.CanDeleteApplication(app) {
|
||||
return errors.New("Not allowed to delete application")
|
||||
}
|
||||
return svc.application.DeleteByID(id)
|
||||
}
|
||||
|
||||
var _ ApplicationService = &application{}
|
||||
185
system/internal/service/auth.go
Normal file
185
system/internal/service/auth.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"github.com/markbates/goth"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/crusttech/crust/system/internal/repository"
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
type (
|
||||
auth struct {
|
||||
db db
|
||||
ctx context.Context
|
||||
|
||||
credentials repository.CredentialsRepository
|
||||
users repository.UserRepository
|
||||
}
|
||||
|
||||
AuthService interface {
|
||||
With(ctx context.Context) AuthService
|
||||
|
||||
Social(profile goth.User) (*types.User, error)
|
||||
|
||||
CheckPassword(username, password string) (*types.User, error)
|
||||
ChangePassword(user *types.User, password string) error
|
||||
CheckCredentials(credentialsID uint64, secret string) (*types.User, error)
|
||||
RevokeCredentialsByID(user *types.User, credentialsID uint64) error
|
||||
}
|
||||
)
|
||||
|
||||
func Auth() AuthService {
|
||||
return (&auth{}).With(context.Background())
|
||||
}
|
||||
|
||||
func (svc *auth) With(ctx context.Context) AuthService {
|
||||
db := repository.DB(ctx)
|
||||
return &auth{
|
||||
db: db,
|
||||
ctx: ctx,
|
||||
|
||||
credentials: repository.Credentials(ctx, db),
|
||||
users: repository.User(ctx, db),
|
||||
}
|
||||
}
|
||||
|
||||
// Social user verifies existance by using email value from social profile and creates user if needed
|
||||
//
|
||||
// It does not update user's info
|
||||
func (svc *auth) Social(profile goth.User) (u *types.User, err error) {
|
||||
var kind types.CredentialsKind
|
||||
|
||||
switch profile.Provider {
|
||||
case "facebook", "gplus", "github", "linkedin":
|
||||
kind = types.CredentialsKind(profile.Provider)
|
||||
default:
|
||||
return nil, errors.New("Unsupported provider")
|
||||
}
|
||||
|
||||
if profile.Email == "" {
|
||||
return nil, errors.New("Can not use profile data without an email")
|
||||
}
|
||||
|
||||
return u, svc.db.Transaction(func() error {
|
||||
var c *types.Credentials
|
||||
if cc, err := svc.credentials.FindByCredentials(kind, profile.UserID); err == nil {
|
||||
// Credentials found, load user
|
||||
for _, c := range cc {
|
||||
if !c.Valid() {
|
||||
continue
|
||||
}
|
||||
|
||||
if u, err = svc.users.FindByID(c.OwnerID); err != nil {
|
||||
return nil
|
||||
} else if u.Valid() && u.Email != profile.Email {
|
||||
return errors.Errorf(
|
||||
"Refusing to authenticate with non matching emails (profile: %v, db: %v) on credentials (ID: %v)",
|
||||
profile.Email,
|
||||
u.Email,
|
||||
c.ID)
|
||||
} else if u.Valid() {
|
||||
// Valid user, matching emails. Bingo!
|
||||
return nil
|
||||
} else {
|
||||
// Scenario: linked to an invalid user
|
||||
u = nil
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// If we could not find anything useful,
|
||||
// we can search for user via email
|
||||
} else {
|
||||
// A serious error occured, bail out...
|
||||
return err
|
||||
}
|
||||
|
||||
// Find user via his email
|
||||
if u, err = svc.users.FindByEmail(profile.Email); err == repository.ErrUserNotFound {
|
||||
// In case we do not have this email, create a new user
|
||||
u = &types.User{
|
||||
Email: profile.Email,
|
||||
Name: profile.Name,
|
||||
Username: profile.NickName,
|
||||
Handle: profile.NickName,
|
||||
}
|
||||
|
||||
if u, err = svc.users.Create(u); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c = &types.Credentials{
|
||||
Kind: kind,
|
||||
OwnerID: u.ID,
|
||||
Credentials: profile.UserID,
|
||||
}
|
||||
|
||||
if !profile.ExpiresAt.IsZero() {
|
||||
// Copy expiration date when provided
|
||||
c.ExpiresAt = &profile.ExpiresAt
|
||||
}
|
||||
|
||||
if c, err = svc.credentials.Create(c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf(
|
||||
"Autheticated user (%v, %v) via %s, created user and credentials (%v)",
|
||||
u.ID,
|
||||
u.Email,
|
||||
profile.Provider,
|
||||
c.ID,
|
||||
)
|
||||
|
||||
// Owner created
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
} else if !u.Valid() {
|
||||
return errors.Errorf(
|
||||
"Social login to an invalid/suspended user (user ID: %v)",
|
||||
u.ID,
|
||||
)
|
||||
}
|
||||
|
||||
log.Printf(
|
||||
"Autheticated user (%v, %v) via %s, existing user",
|
||||
u.ID,
|
||||
u.Email,
|
||||
profile.Provider,
|
||||
)
|
||||
|
||||
// Owner loaded, carry on.
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// CheckPassword verifies username/password combination
|
||||
//
|
||||
// Expects plain text password as an input
|
||||
func (svc *auth) CheckPassword(username, password string) (*types.User, error) {
|
||||
panic("svc.auth.CheckPassword, not implemented")
|
||||
}
|
||||
|
||||
// ChangePassword (soft) deletes old password entry and creates a new one
|
||||
//
|
||||
// Expects plain text password as an input
|
||||
func (svc *auth) ChangePassword(user *types.User, password string) error {
|
||||
panic("svc.auth.ChangePassword, not implemented")
|
||||
}
|
||||
|
||||
// CheckCredentials searches for credentials/secret combination and returns loaded user if successful
|
||||
func (svc *auth) CheckCredentials(credentialsID uint64, secret string) (*types.User, error) {
|
||||
panic("svc.auth.CheckCredentials, not implemented")
|
||||
}
|
||||
|
||||
// RevokeCredentialsByID (soft) deletes credentials by id
|
||||
func (svc *auth) RevokeCredentialsByID(user *types.User, credentialsID uint64) error {
|
||||
panic("svc.auth.RevokeCredentialsByID, not implemented")
|
||||
}
|
||||
|
||||
var _ AuthService = &auth{}
|
||||
85
system/internal/service/auth_test.go
Normal file
85
system/internal/service/auth_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/markbates/goth"
|
||||
|
||||
"github.com/crusttech/crust/system/internal/repository"
|
||||
repomock "github.com/crusttech/crust/system/internal/repository/mocks"
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
func TestSocialSigninWithExistingCredentials(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
|
||||
var u = &types.User{ID: 300000, Email: "foo@example.tld"}
|
||||
var c = &types.Credentials{ID: 200000, OwnerID: u.ID}
|
||||
var p = goth.User{UserID: "some-profile-id", Provider: "gplus", Email: u.Email}
|
||||
|
||||
crdRpoMock := repomock.NewMockCredentialsRepository(mockCtrl)
|
||||
crdRpoMock.EXPECT().
|
||||
FindByCredentials(types.CredentialsKindGPlus, p.UserID).
|
||||
Times(1).
|
||||
Return(types.CredentialsSet{c}, nil)
|
||||
|
||||
usrRpoMock := repomock.NewMockUserRepository(mockCtrl)
|
||||
usrRpoMock.EXPECT().FindByID(u.ID).Times(1).Return(u, nil)
|
||||
|
||||
svc := &auth{
|
||||
db: &mockDB{},
|
||||
users: usrRpoMock,
|
||||
credentials: crdRpoMock,
|
||||
}
|
||||
|
||||
{
|
||||
auser, err := svc.Social(p)
|
||||
assert(t, err == nil, "Auth.Social error: %+v", err)
|
||||
assert(t, auser.ID == u.ID, "Did not receive expected user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocialSigninWithNewUserCredentials(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
|
||||
var u = &types.User{ID: 300000, Email: "foo@example.tld"}
|
||||
var c = &types.Credentials{ID: 200000, OwnerID: u.ID}
|
||||
var p = goth.User{UserID: "some-profile-id", Provider: "gplus", Email: u.Email}
|
||||
|
||||
crdRpoMock := repomock.NewMockCredentialsRepository(mockCtrl)
|
||||
crdRpoMock.EXPECT().
|
||||
FindByCredentials(types.CredentialsKindGPlus, p.UserID).
|
||||
Times(1).
|
||||
Return(types.CredentialsSet{}, nil)
|
||||
|
||||
crdRpoMock.EXPECT().
|
||||
Create(&types.Credentials{Kind: types.CredentialsKindGPlus, OwnerID: u.ID, Credentials: p.UserID}).
|
||||
Times(1).
|
||||
Return(c, nil)
|
||||
|
||||
usrRpoMock := repomock.NewMockUserRepository(mockCtrl)
|
||||
usrRpoMock.EXPECT().
|
||||
FindByEmail(u.Email).
|
||||
Times(1).
|
||||
Return(nil, repository.ErrUserNotFound)
|
||||
|
||||
usrRpoMock.EXPECT().
|
||||
Create(&types.User{Email: "foo@example.tld"}).
|
||||
Times(1).
|
||||
Return(u, nil)
|
||||
|
||||
svc := &auth{
|
||||
db: &mockDB{},
|
||||
users: usrRpoMock,
|
||||
credentials: crdRpoMock,
|
||||
}
|
||||
|
||||
{
|
||||
auser, err := svc.Social(p)
|
||||
assert(t, err == nil, "Auth.Social error: %+v", err)
|
||||
assert(t, auser.ID == u.ID, "Did not receive expected user")
|
||||
}
|
||||
}
|
||||
9
system/internal/service/error.go
Normal file
9
system/internal/service/error.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package service
|
||||
|
||||
type (
|
||||
serviceError string
|
||||
)
|
||||
|
||||
func (e serviceError) Error() string {
|
||||
return "crust.messaging.service." + string(e)
|
||||
}
|
||||
58
system/internal/service/main_test.go
Normal file
58
system/internal/service/main_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
"github.com/namsral/flag"
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
systemMigrate "github.com/crusttech/crust/system/db"
|
||||
)
|
||||
|
||||
type mockDB struct{}
|
||||
|
||||
func (mockDB) Transaction(callback func() error) error { return callback() }
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// @todo this is a very optimistic initialization, make it more robust
|
||||
godotenv.Load("../../.env")
|
||||
|
||||
prefix := "system"
|
||||
dsn := ""
|
||||
|
||||
p := func(s string) string {
|
||||
return prefix + "-" + s
|
||||
}
|
||||
|
||||
flag.StringVar(&dsn, p("db-dsn"), "crust:crust@tcp(db1:3306)/crust?collation=utf8mb4_general_ci", "DSN for database connection")
|
||||
flag.Parse()
|
||||
|
||||
factory.Database.Add("default", dsn)
|
||||
factory.Database.Add("system", dsn)
|
||||
|
||||
db := factory.Database.MustGet()
|
||||
db.Profiler = &factory.Database.ProfilerStdout
|
||||
|
||||
// migrate database schema
|
||||
if err := systemMigrate.Migrate(db); err != nil {
|
||||
log.Printf("Error running migrations: %+v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func assert(t *testing.T, ok bool, format string, args ...interface{}) bool {
|
||||
if !ok {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
caller := fmt.Sprintf("\nAsserted at:%s:%d", file, line)
|
||||
|
||||
t.Fatalf(format+caller, args...)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
94
system/internal/service/organisation.go
Normal file
94
system/internal/service/organisation.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/system/internal/repository"
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
type (
|
||||
organisation struct {
|
||||
db *factory.DB
|
||||
ctx context.Context
|
||||
|
||||
rpo repository.OrganisationRepository
|
||||
}
|
||||
|
||||
OrganisationService interface {
|
||||
With(ctx context.Context) OrganisationService
|
||||
|
||||
FindByID(organisationID uint64) (*types.Organisation, error)
|
||||
Find(filter *types.OrganisationFilter) ([]*types.Organisation, error)
|
||||
|
||||
Create(organisation *types.Organisation) (*types.Organisation, error)
|
||||
Update(organisation *types.Organisation) (*types.Organisation, error)
|
||||
|
||||
Archive(ID uint64) error
|
||||
Unarchive(ID uint64) error
|
||||
Delete(ID uint64) error
|
||||
}
|
||||
)
|
||||
|
||||
func Organisation() OrganisationService {
|
||||
return (&organisation{}).With(context.Background())
|
||||
}
|
||||
|
||||
func (svc *organisation) With(ctx context.Context) OrganisationService {
|
||||
db := repository.DB(ctx)
|
||||
return &organisation{
|
||||
db: db,
|
||||
ctx: ctx,
|
||||
rpo: repository.Organisation(ctx, db),
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *organisation) FindByID(id uint64) (*types.Organisation, error) {
|
||||
// @todo: permission check if current user can read organisation
|
||||
return svc.rpo.FindOrganisationByID(id)
|
||||
}
|
||||
|
||||
func (svc *organisation) Find(filter *types.OrganisationFilter) ([]*types.Organisation, error) {
|
||||
// @todo: permission check to return only organisations that organisation has access to
|
||||
// @todo: actual searching not just a full select
|
||||
return svc.rpo.FindOrganisations(filter)
|
||||
}
|
||||
|
||||
func (svc *organisation) Create(mod *types.Organisation) (*types.Organisation, error) {
|
||||
// @todo: permission check if current user can add/edit organisation
|
||||
// @todo: make sure archived & deleted entries can not be edited
|
||||
|
||||
return svc.rpo.CreateOrganisation(mod)
|
||||
}
|
||||
|
||||
func (svc *organisation) Update(mod *types.Organisation) (*types.Organisation, error) {
|
||||
// @todo: permission check if current user can add/edit organisation
|
||||
// @todo: make sure archived & deleted entries can not be edited
|
||||
|
||||
return svc.rpo.UpdateOrganisation(mod)
|
||||
}
|
||||
|
||||
func (svc *organisation) Delete(id uint64) error {
|
||||
// @todo: permissions check if current user can remove organisation
|
||||
// @todo: make history unavailable
|
||||
// @todo: notify users that organisation has been removed (remove from web UI)
|
||||
return svc.rpo.DeleteOrganisationByID(id)
|
||||
}
|
||||
|
||||
func (svc *organisation) Archive(id uint64) error {
|
||||
// @todo: make history unavailable
|
||||
// @todo: notify users that organisation has been removed (remove from web UI)
|
||||
// @todo: permissions check if current user can archive organisation
|
||||
return svc.rpo.ArchiveOrganisationByID(id)
|
||||
}
|
||||
|
||||
func (svc *organisation) Unarchive(id uint64) error {
|
||||
// @todo: make history unavailable
|
||||
// @todo: notify users that organisation has been removed (remove from web UI)
|
||||
// @todo: permissions check if current user can unarchive organisation
|
||||
return svc.rpo.UnarchiveOrganisationByID(id)
|
||||
}
|
||||
|
||||
var _ OrganisationService = &organisation{}
|
||||
135
system/internal/service/organisation_mock_test.go
Normal file
135
system/internal/service/organisation_mock_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: system/internal/service/organisation.go
|
||||
|
||||
// Package service is a generated GoMock package.
|
||||
package service
|
||||
|
||||
import (
|
||||
context "context"
|
||||
types "github.com/crusttech/crust/system/types"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockOrganisationService is a mock of OrganisationService interface
|
||||
type MockOrganisationService struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockOrganisationServiceMockRecorder
|
||||
}
|
||||
|
||||
// MockOrganisationServiceMockRecorder is the mock recorder for MockOrganisationService
|
||||
type MockOrganisationServiceMockRecorder struct {
|
||||
mock *MockOrganisationService
|
||||
}
|
||||
|
||||
// NewMockOrganisationService creates a new mock instance
|
||||
func NewMockOrganisationService(ctrl *gomock.Controller) *MockOrganisationService {
|
||||
mock := &MockOrganisationService{ctrl: ctrl}
|
||||
mock.recorder = &MockOrganisationServiceMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockOrganisationService) EXPECT() *MockOrganisationServiceMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// With mocks base method
|
||||
func (m *MockOrganisationService) With(ctx context.Context) OrganisationService {
|
||||
ret := m.ctrl.Call(m, "With", ctx)
|
||||
ret0, _ := ret[0].(OrganisationService)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// With indicates an expected call of With
|
||||
func (mr *MockOrganisationServiceMockRecorder) With(ctx interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "With", reflect.TypeOf((*MockOrganisationService)(nil).With), ctx)
|
||||
}
|
||||
|
||||
// FindByID mocks base method
|
||||
func (m *MockOrganisationService) FindByID(organisationID uint64) (*types.Organisation, error) {
|
||||
ret := m.ctrl.Call(m, "FindByID", organisationID)
|
||||
ret0, _ := ret[0].(*types.Organisation)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByID indicates an expected call of FindByID
|
||||
func (mr *MockOrganisationServiceMockRecorder) FindByID(organisationID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByID", reflect.TypeOf((*MockOrganisationService)(nil).FindByID), organisationID)
|
||||
}
|
||||
|
||||
// Find mocks base method
|
||||
func (m *MockOrganisationService) Find(filter *types.OrganisationFilter) ([]*types.Organisation, error) {
|
||||
ret := m.ctrl.Call(m, "Find", filter)
|
||||
ret0, _ := ret[0].([]*types.Organisation)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Find indicates an expected call of Find
|
||||
func (mr *MockOrganisationServiceMockRecorder) Find(filter interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Find", reflect.TypeOf((*MockOrganisationService)(nil).Find), filter)
|
||||
}
|
||||
|
||||
// Create mocks base method
|
||||
func (m *MockOrganisationService) Create(organisation *types.Organisation) (*types.Organisation, error) {
|
||||
ret := m.ctrl.Call(m, "Create", organisation)
|
||||
ret0, _ := ret[0].(*types.Organisation)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Create indicates an expected call of Create
|
||||
func (mr *MockOrganisationServiceMockRecorder) Create(organisation interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockOrganisationService)(nil).Create), organisation)
|
||||
}
|
||||
|
||||
// Update mocks base method
|
||||
func (m *MockOrganisationService) Update(organisation *types.Organisation) (*types.Organisation, error) {
|
||||
ret := m.ctrl.Call(m, "Update", organisation)
|
||||
ret0, _ := ret[0].(*types.Organisation)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Update indicates an expected call of Update
|
||||
func (mr *MockOrganisationServiceMockRecorder) Update(organisation interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockOrganisationService)(nil).Update), organisation)
|
||||
}
|
||||
|
||||
// Archive mocks base method
|
||||
func (m *MockOrganisationService) Archive(ID uint64) error {
|
||||
ret := m.ctrl.Call(m, "Archive", ID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Archive indicates an expected call of Archive
|
||||
func (mr *MockOrganisationServiceMockRecorder) Archive(ID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Archive", reflect.TypeOf((*MockOrganisationService)(nil).Archive), ID)
|
||||
}
|
||||
|
||||
// Unarchive mocks base method
|
||||
func (m *MockOrganisationService) Unarchive(ID uint64) error {
|
||||
ret := m.ctrl.Call(m, "Unarchive", ID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Unarchive indicates an expected call of Unarchive
|
||||
func (mr *MockOrganisationServiceMockRecorder) Unarchive(ID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unarchive", reflect.TypeOf((*MockOrganisationService)(nil).Unarchive), ID)
|
||||
}
|
||||
|
||||
// Delete mocks base method
|
||||
func (m *MockOrganisationService) Delete(ID uint64) error {
|
||||
ret := m.ctrl.Call(m, "Delete", ID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete
|
||||
func (mr *MockOrganisationServiceMockRecorder) Delete(ID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockOrganisationService)(nil).Delete), ID)
|
||||
}
|
||||
135
system/internal/service/permissions.go
Normal file
135
system/internal/service/permissions.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
internalRules "github.com/crusttech/crust/internal/rules"
|
||||
"github.com/crusttech/crust/system/internal/repository"
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
type (
|
||||
permissions struct {
|
||||
db db
|
||||
ctx context.Context
|
||||
|
||||
rules RulesService
|
||||
}
|
||||
|
||||
PermissionsService interface {
|
||||
With(context.Context) PermissionsService
|
||||
|
||||
Effective() (ee []effectivePermission, err error)
|
||||
|
||||
CanCreateOrganisation() bool
|
||||
CanCreateRole() bool
|
||||
CanCreateApplication() bool
|
||||
|
||||
CanReadRole(rl *types.Role) bool
|
||||
CanUpdateRole(rl *types.Role) bool
|
||||
CanDeleteRole(rl *types.Role) bool
|
||||
CanManageRoleMembers(rl *types.Role) bool
|
||||
|
||||
CanReadApplication(app *types.Application) bool
|
||||
CanUpdateApplication(app *types.Application) bool
|
||||
CanDeleteApplication(app *types.Application) bool
|
||||
}
|
||||
|
||||
effectivePermission struct {
|
||||
Resource string `json:"resource"`
|
||||
Operation string `json:"operation"`
|
||||
Allow bool `json:"allow"`
|
||||
}
|
||||
)
|
||||
|
||||
func Permissions() PermissionsService {
|
||||
return (&permissions{
|
||||
rules: DefaultRules,
|
||||
}).With(context.Background())
|
||||
}
|
||||
|
||||
func (p *permissions) With(ctx context.Context) PermissionsService {
|
||||
db := repository.DB(ctx)
|
||||
return &permissions{
|
||||
db: db,
|
||||
ctx: ctx,
|
||||
|
||||
rules: p.rules.With(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *permissions) Effective() (ee []effectivePermission, err error) {
|
||||
ep := func(res, op string, allow bool) effectivePermission {
|
||||
return effectivePermission{
|
||||
Resource: res,
|
||||
Operation: op,
|
||||
Allow: allow,
|
||||
}
|
||||
}
|
||||
|
||||
ee = append(ee, ep("system", "access", p.CanAccess()))
|
||||
ee = append(ee, ep("system", "application.create", p.CanCreateApplication()))
|
||||
ee = append(ee, ep("system", "role.create", p.CanCreateRole()))
|
||||
ee = append(ee, ep("system", "organisation.create", p.CanCreateOrganisation()))
|
||||
ee = append(ee, ep("system", "grant", p.CanCreateRole()))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (p *permissions) CanAccess() bool {
|
||||
return p.checkAccess("system", "access")
|
||||
}
|
||||
|
||||
func (p *permissions) CanCreateOrganisation() bool {
|
||||
return p.checkAccess("system", "organisation.create")
|
||||
}
|
||||
|
||||
func (p *permissions) CanCreateRole() bool {
|
||||
return p.checkAccess("system", "role.create")
|
||||
}
|
||||
|
||||
func (p *permissions) CanCreateApplication() bool {
|
||||
return p.checkAccess("system", "application.create")
|
||||
}
|
||||
|
||||
func (p *permissions) CanReadRole(rl *types.Role) bool {
|
||||
return p.checkAccess(rl.Resource().String(), "read", p.allow())
|
||||
}
|
||||
|
||||
func (p *permissions) CanUpdateRole(rl *types.Role) bool {
|
||||
return p.checkAccess(rl.Resource().String(), "update")
|
||||
}
|
||||
|
||||
func (p *permissions) CanDeleteRole(rl *types.Role) bool {
|
||||
return p.checkAccess(rl.Resource().String(), "delete")
|
||||
}
|
||||
|
||||
func (p *permissions) CanManageRoleMembers(rl *types.Role) bool {
|
||||
return p.checkAccess(rl.Resource().String(), "members.manage")
|
||||
}
|
||||
|
||||
func (p *permissions) CanReadApplication(app *types.Application) bool {
|
||||
return p.checkAccess(app.Resource().String(), "read", p.allow())
|
||||
}
|
||||
|
||||
func (p *permissions) CanUpdateApplication(app *types.Application) bool {
|
||||
return p.checkAccess(app.Resource().String(), "update")
|
||||
}
|
||||
|
||||
func (p *permissions) CanDeleteApplication(app *types.Application) bool {
|
||||
return p.checkAccess(app.Resource().String(), "delete")
|
||||
}
|
||||
|
||||
func (p *permissions) checkAccess(resource string, operation string, fallbacks ...internalRules.CheckAccessFunc) bool {
|
||||
access := p.rules.Check(resource, operation, fallbacks...)
|
||||
if access == internalRules.Allow {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p permissions) allow() func() internalRules.Access {
|
||||
return func() internalRules.Access {
|
||||
return internalRules.Allow
|
||||
}
|
||||
}
|
||||
179
system/internal/service/role.go
Normal file
179
system/internal/service/role.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/system/internal/repository"
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
type (
|
||||
role struct {
|
||||
db *factory.DB
|
||||
ctx context.Context
|
||||
|
||||
prm PermissionsService
|
||||
|
||||
role repository.RoleRepository
|
||||
}
|
||||
|
||||
RoleService interface {
|
||||
With(ctx context.Context) RoleService
|
||||
|
||||
FindByID(roleID uint64) (*types.Role, error)
|
||||
Find(filter *types.RoleFilter) ([]*types.Role, error)
|
||||
|
||||
Create(role *types.Role) (*types.Role, error)
|
||||
Update(role *types.Role) (*types.Role, error)
|
||||
Merge(roleID, targetroleID uint64) error
|
||||
Move(roleID, organisationID uint64) error
|
||||
|
||||
Archive(ID uint64) error
|
||||
Unarchive(ID uint64) error
|
||||
Delete(ID uint64) error
|
||||
|
||||
MemberList(roleID uint64) ([]*types.RoleMember, error)
|
||||
MemberAdd(roleID, userID uint64) error
|
||||
MemberRemove(roleID, userID uint64) error
|
||||
}
|
||||
)
|
||||
|
||||
func Role() RoleService {
|
||||
return (&role{
|
||||
prm: DefaultPermissions,
|
||||
}).With(context.Background())
|
||||
}
|
||||
|
||||
func (svc *role) With(ctx context.Context) RoleService {
|
||||
db := repository.DB(ctx)
|
||||
return &role{
|
||||
db: db,
|
||||
ctx: ctx,
|
||||
|
||||
prm: svc.prm.With(ctx),
|
||||
|
||||
role: repository.Role(ctx, db),
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *role) FindByID(id uint64) (*types.Role, error) {
|
||||
role, err := svc.role.FindByID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !svc.prm.CanReadRole(role) {
|
||||
return nil, errors.New("Not allowed to read role")
|
||||
}
|
||||
return role, nil
|
||||
}
|
||||
|
||||
func (svc *role) Find(filter *types.RoleFilter) ([]*types.Role, error) {
|
||||
roles, err := svc.role.Find(filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret := []*types.Role{}
|
||||
for _, role := range roles {
|
||||
if svc.prm.CanReadRole(role) {
|
||||
ret = append(ret, role)
|
||||
}
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (svc *role) Create(mod *types.Role) (*types.Role, error) {
|
||||
if !svc.prm.CanCreateRole() {
|
||||
return nil, errors.New("Not allowed to create role")
|
||||
}
|
||||
return svc.role.Create(mod)
|
||||
}
|
||||
|
||||
func (svc *role) Update(mod *types.Role) (t *types.Role, err error) {
|
||||
if !svc.prm.CanUpdateRole(mod) {
|
||||
return nil, errors.New("Not allowed to update role")
|
||||
}
|
||||
|
||||
// @todo: make sure archived & deleted entries can not be edited
|
||||
|
||||
return t, svc.db.Transaction(func() (err error) {
|
||||
if t, err = svc.role.FindByID(mod.ID); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Assign changed values
|
||||
t.Name = mod.Name
|
||||
t.Handle = mod.Handle
|
||||
|
||||
if t, err = svc.role.Update(t); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (svc *role) Delete(id uint64) error {
|
||||
// @todo: make history unavailable
|
||||
// @todo: notify users that role has been removed (remove from web UI)
|
||||
|
||||
rl := &types.Role{ID: id}
|
||||
if !svc.prm.CanDeleteRole(rl) {
|
||||
return errors.New("Not allowed to delete role")
|
||||
}
|
||||
return svc.role.DeleteByID(id)
|
||||
}
|
||||
|
||||
func (svc *role) Archive(id uint64) error {
|
||||
// @todo: make history unavailable
|
||||
// @todo: notify users that role has been removed (remove from web UI)
|
||||
// @todo: permissions check if current user can remove role
|
||||
return svc.role.ArchiveByID(id)
|
||||
}
|
||||
|
||||
func (svc *role) Unarchive(id uint64) error {
|
||||
// @todo: permissions check if current user can unarchive role
|
||||
// @todo: make history accessible
|
||||
// @todo: notify users that role has been unarchived
|
||||
return svc.role.UnarchiveByID(id)
|
||||
}
|
||||
|
||||
func (svc *role) Merge(id, targetroleID uint64) error {
|
||||
// @todo: permission check if current user can merge role
|
||||
return svc.role.MergeByID(id, targetroleID)
|
||||
}
|
||||
|
||||
func (svc *role) Move(id, targetOrganisationID uint64) error {
|
||||
// @todo: permission check if current user can move role to another organisation
|
||||
return svc.role.MoveByID(id, targetOrganisationID)
|
||||
}
|
||||
|
||||
func (svc *role) MemberList(roleID uint64) ([]*types.RoleMember, error) {
|
||||
rl := &types.Role{ID: roleID}
|
||||
if !svc.prm.CanManageRoleMembers(rl) {
|
||||
return nil, errors.New("Not allowed to manage role members")
|
||||
}
|
||||
return svc.role.MemberFindByRoleID(roleID)
|
||||
}
|
||||
|
||||
func (svc *role) MemberAdd(roleID, userID uint64) error {
|
||||
rl := &types.Role{ID: roleID}
|
||||
if !svc.prm.CanManageRoleMembers(rl) {
|
||||
return errors.New("Not allowed to manage role members")
|
||||
}
|
||||
return svc.role.MemberAddByID(roleID, userID)
|
||||
}
|
||||
|
||||
func (svc *role) MemberRemove(roleID, userID uint64) error {
|
||||
rl := &types.Role{ID: roleID}
|
||||
if !svc.prm.CanManageRoleMembers(rl) {
|
||||
return errors.New("Not allowed to manage role members")
|
||||
}
|
||||
return svc.role.MemberRemoveByID(roleID, userID)
|
||||
}
|
||||
|
||||
var _ RoleService = &role{}
|
||||
196
system/internal/service/role_mock_test.go
Normal file
196
system/internal/service/role_mock_test.go
Normal file
@@ -0,0 +1,196 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: system/internal/service/role.go
|
||||
|
||||
// Package service is a generated GoMock package.
|
||||
package service
|
||||
|
||||
import (
|
||||
context "context"
|
||||
types "github.com/crusttech/crust/system/types"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockRoleService is a mock of RoleService interface
|
||||
type MockRoleService struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRoleServiceMockRecorder
|
||||
}
|
||||
|
||||
// MockRoleServiceMockRecorder is the mock recorder for MockRoleService
|
||||
type MockRoleServiceMockRecorder struct {
|
||||
mock *MockRoleService
|
||||
}
|
||||
|
||||
// NewMockRoleService creates a new mock instance
|
||||
func NewMockRoleService(ctrl *gomock.Controller) *MockRoleService {
|
||||
mock := &MockRoleService{ctrl: ctrl}
|
||||
mock.recorder = &MockRoleServiceMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockRoleService) EXPECT() *MockRoleServiceMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// With mocks base method
|
||||
func (m *MockRoleService) With(ctx context.Context) RoleService {
|
||||
ret := m.ctrl.Call(m, "With", ctx)
|
||||
ret0, _ := ret[0].(RoleService)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// With indicates an expected call of With
|
||||
func (mr *MockRoleServiceMockRecorder) With(ctx interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "With", reflect.TypeOf((*MockRoleService)(nil).With), ctx)
|
||||
}
|
||||
|
||||
// FindByID mocks base method
|
||||
func (m *MockRoleService) FindByID(roleID uint64) (*types.Role, error) {
|
||||
ret := m.ctrl.Call(m, "FindByID", roleID)
|
||||
ret0, _ := ret[0].(*types.Role)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByID indicates an expected call of FindByID
|
||||
func (mr *MockRoleServiceMockRecorder) FindByID(roleID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByID", reflect.TypeOf((*MockRoleService)(nil).FindByID), roleID)
|
||||
}
|
||||
|
||||
// Find mocks base method
|
||||
func (m *MockRoleService) Find(filter *types.RoleFilter) ([]*types.Role, error) {
|
||||
ret := m.ctrl.Call(m, "Find", filter)
|
||||
ret0, _ := ret[0].([]*types.Role)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Find indicates an expected call of Find
|
||||
func (mr *MockRoleServiceMockRecorder) Find(filter interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Find", reflect.TypeOf((*MockRoleService)(nil).Find), filter)
|
||||
}
|
||||
|
||||
// Create mocks base method
|
||||
func (m *MockRoleService) Create(role *types.Role) (*types.Role, error) {
|
||||
ret := m.ctrl.Call(m, "Create", role)
|
||||
ret0, _ := ret[0].(*types.Role)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Create indicates an expected call of Create
|
||||
func (mr *MockRoleServiceMockRecorder) Create(role interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockRoleService)(nil).Create), role)
|
||||
}
|
||||
|
||||
// Update mocks base method
|
||||
func (m *MockRoleService) Update(role *types.Role) (*types.Role, error) {
|
||||
ret := m.ctrl.Call(m, "Update", role)
|
||||
ret0, _ := ret[0].(*types.Role)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Update indicates an expected call of Update
|
||||
func (mr *MockRoleServiceMockRecorder) Update(role interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockRoleService)(nil).Update), role)
|
||||
}
|
||||
|
||||
// Merge mocks base method
|
||||
func (m *MockRoleService) Merge(roleID, targetroleID uint64) error {
|
||||
ret := m.ctrl.Call(m, "Merge", roleID, targetroleID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Merge indicates an expected call of Merge
|
||||
func (mr *MockRoleServiceMockRecorder) Merge(roleID, targetroleID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Merge", reflect.TypeOf((*MockRoleService)(nil).Merge), roleID, targetroleID)
|
||||
}
|
||||
|
||||
// Move mocks base method
|
||||
func (m *MockRoleService) Move(roleID, organisationID uint64) error {
|
||||
ret := m.ctrl.Call(m, "Move", roleID, organisationID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Move indicates an expected call of Move
|
||||
func (mr *MockRoleServiceMockRecorder) Move(roleID, organisationID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Move", reflect.TypeOf((*MockRoleService)(nil).Move), roleID, organisationID)
|
||||
}
|
||||
|
||||
// Archive mocks base method
|
||||
func (m *MockRoleService) Archive(ID uint64) error {
|
||||
ret := m.ctrl.Call(m, "Archive", ID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Archive indicates an expected call of Archive
|
||||
func (mr *MockRoleServiceMockRecorder) Archive(ID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Archive", reflect.TypeOf((*MockRoleService)(nil).Archive), ID)
|
||||
}
|
||||
|
||||
// Unarchive mocks base method
|
||||
func (m *MockRoleService) Unarchive(ID uint64) error {
|
||||
ret := m.ctrl.Call(m, "Unarchive", ID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Unarchive indicates an expected call of Unarchive
|
||||
func (mr *MockRoleServiceMockRecorder) Unarchive(ID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unarchive", reflect.TypeOf((*MockRoleService)(nil).Unarchive), ID)
|
||||
}
|
||||
|
||||
// Delete mocks base method
|
||||
func (m *MockRoleService) Delete(ID uint64) error {
|
||||
ret := m.ctrl.Call(m, "Delete", ID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete
|
||||
func (mr *MockRoleServiceMockRecorder) Delete(ID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockRoleService)(nil).Delete), ID)
|
||||
}
|
||||
|
||||
// MemberList mocks base method
|
||||
func (m *MockRoleService) MemberList(roleID uint64) ([]*types.RoleMember, error) {
|
||||
ret := m.ctrl.Call(m, "MemberList", roleID)
|
||||
ret0, _ := ret[0].([]*types.RoleMember)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MemberList indicates an expected call of MemberList
|
||||
func (mr *MockRoleServiceMockRecorder) MemberList(roleID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MemberList", reflect.TypeOf((*MockRoleService)(nil).MemberList), roleID)
|
||||
}
|
||||
|
||||
// MemberAdd mocks base method
|
||||
func (m *MockRoleService) MemberAdd(roleID, userID uint64) error {
|
||||
ret := m.ctrl.Call(m, "MemberAdd", roleID, userID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// MemberAdd indicates an expected call of MemberAdd
|
||||
func (mr *MockRoleServiceMockRecorder) MemberAdd(roleID, userID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MemberAdd", reflect.TypeOf((*MockRoleService)(nil).MemberAdd), roleID, userID)
|
||||
}
|
||||
|
||||
// MemberRemove mocks base method
|
||||
func (m *MockRoleService) MemberRemove(roleID, userID uint64) error {
|
||||
ret := m.ctrl.Call(m, "MemberRemove", roleID, userID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// MemberRemove indicates an expected call of MemberRemove
|
||||
func (mr *MockRoleServiceMockRecorder) MemberRemove(roleID, userID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MemberRemove", reflect.TypeOf((*MockRoleService)(nil).MemberRemove), roleID, userID)
|
||||
}
|
||||
141
system/internal/service/rules.go
Normal file
141
system/internal/service/rules.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
internalRules "github.com/crusttech/crust/internal/rules"
|
||||
"github.com/crusttech/crust/system/internal/repository"
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
const (
|
||||
delimiter = ":"
|
||||
)
|
||||
|
||||
type (
|
||||
rules struct {
|
||||
db db
|
||||
ctx context.Context
|
||||
|
||||
resources internalRules.ResourcesInterface
|
||||
}
|
||||
|
||||
EffectiveRules struct {
|
||||
Resource string `json:"resource"`
|
||||
Operation string `json:"operation"`
|
||||
Allow bool `json:"allow"`
|
||||
}
|
||||
|
||||
RulesService interface {
|
||||
With(ctx context.Context) RulesService
|
||||
|
||||
List() (interface{}, error)
|
||||
Effective(filter string) ([]EffectiveRules, error)
|
||||
|
||||
Check(resource string, operation string, fallbacks ...internalRules.CheckAccessFunc) internalRules.Access
|
||||
|
||||
Read(roleID uint64) (interface{}, error)
|
||||
Update(roleID uint64, rules []internalRules.Rule) (interface{}, error)
|
||||
Delete(roleID uint64) (interface{}, error)
|
||||
}
|
||||
)
|
||||
|
||||
func Rules() RulesService {
|
||||
return (&rules{}).With(context.Background())
|
||||
}
|
||||
|
||||
func (p *rules) With(ctx context.Context) RulesService {
|
||||
db := repository.DB(ctx)
|
||||
return &rules{
|
||||
db: db,
|
||||
ctx: ctx,
|
||||
|
||||
resources: internalRules.NewResources(ctx, db),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *rules) List() (interface{}, error) {
|
||||
perms := []types.Permission{}
|
||||
for resource, operations := range permissionList {
|
||||
err := p.checkServiceAccess(resource)
|
||||
if err == nil {
|
||||
for ops := range operations {
|
||||
perms = append(perms, types.Permission{Resource: resource, Operation: ops})
|
||||
}
|
||||
}
|
||||
}
|
||||
return perms, nil
|
||||
}
|
||||
|
||||
func (p *rules) Effective(filter string) (eff []EffectiveRules, err error) {
|
||||
eff = []EffectiveRules{}
|
||||
for resource, operations := range permissionList {
|
||||
// err := p.checkServiceAccess(resource)
|
||||
if err == nil {
|
||||
for ops := range operations {
|
||||
eff = append(eff, EffectiveRules{
|
||||
Resource: resource,
|
||||
Operation: ops,
|
||||
Allow: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (p *rules) Check(resource string, operation string, fallbacks ...internalRules.CheckAccessFunc) internalRules.Access {
|
||||
return p.resources.Check(resource, operation, fallbacks...)
|
||||
}
|
||||
|
||||
func (p *rules) Read(roleID uint64) (interface{}, error) {
|
||||
ret, err := p.resources.Read(roleID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Only display rules under granted scopes.
|
||||
rules := []internalRules.Rule{}
|
||||
for _, rule := range ret {
|
||||
err = p.checkServiceAccess(rule.Resource)
|
||||
if err == nil {
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
}
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
func (p *rules) Update(roleID uint64, rules []internalRules.Rule) (interface{}, error) {
|
||||
for _, rule := range rules {
|
||||
err := validatePermission(rule.Resource, rule.Operation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = p.checkServiceAccess(rule.Resource)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
err := p.resources.Grant(roleID, rules)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p.resources.Read(roleID)
|
||||
}
|
||||
|
||||
func (p *rules) Delete(roleID uint64) (interface{}, error) {
|
||||
return nil, p.resources.Delete(roleID)
|
||||
}
|
||||
|
||||
func (p *rules) checkServiceAccess(resource string) error {
|
||||
service := strings.Split(resource, delimiter)[0]
|
||||
|
||||
grant := p.resources.Check(service, "grant")
|
||||
if grant == internalRules.Allow {
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf("No grant permissions for: %v", service)
|
||||
}
|
||||
160
system/internal/service/rules_test.go
Normal file
160
system/internal/service/rules_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
internalAuth "github.com/crusttech/crust/internal/auth"
|
||||
internalRules "github.com/crusttech/crust/internal/rules"
|
||||
. "github.com/crusttech/crust/internal/test"
|
||||
"github.com/crusttech/crust/system/internal/repository"
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
func TestRules(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode.")
|
||||
return
|
||||
}
|
||||
|
||||
// Create test user and role.
|
||||
user := &types.User{ID: 1337}
|
||||
role := &types.Role{ID: 123456, Name: "Test role"}
|
||||
|
||||
// Write user to context.
|
||||
ctx := internalAuth.SetIdentityToContext(context.Background(), user)
|
||||
|
||||
// Connect do DB.
|
||||
db := factory.Database.MustGet()
|
||||
|
||||
// Create resources interface.
|
||||
resources := internalRules.NewResources(ctx, db)
|
||||
|
||||
// Run tests in transaction to maintain DB state.
|
||||
Error(t, db.Transaction(func() error {
|
||||
db.Delete("sys_rules", "1=1")
|
||||
db.Insert("sys_user", user)
|
||||
db.Insert("sys_role", role)
|
||||
db.Insert("sys_role_member", types.RoleMember{RoleID: role.ID, UserID: user.ID})
|
||||
|
||||
// delete all for test roleID = 123456
|
||||
{
|
||||
err := resources.Delete(role.ID)
|
||||
NoError(t, err, "expected no error, got %+v", err)
|
||||
}
|
||||
|
||||
// Create rules service.
|
||||
rulesSvc := Rules().With(ctx)
|
||||
|
||||
// Update rules for test role, with error.
|
||||
{
|
||||
list := []internalRules.Rule{
|
||||
internalRules.Rule{Resource: "messaging:channel:1", Operation: "message.update.all", Value: internalRules.Allow},
|
||||
}
|
||||
_, err := rulesSvc.Update(role.ID, list)
|
||||
Error(t, err, "expected error == No Allow rule for messaging")
|
||||
}
|
||||
|
||||
// Insert `grant` permission for `messaging` and `system`.
|
||||
{
|
||||
list := []internalRules.Rule{
|
||||
internalRules.Rule{Resource: "system", Operation: "grant", Value: internalRules.Allow},
|
||||
internalRules.Rule{Resource: "messaging", Operation: "grant", Value: internalRules.Allow},
|
||||
}
|
||||
|
||||
err := resources.Grant(role.ID, list)
|
||||
NoError(t, err, "expected no error, got %v+", err)
|
||||
}
|
||||
|
||||
// List possible permissions with `messaging` and `system` grants.
|
||||
{
|
||||
ret, err := rulesSvc.List()
|
||||
NoError(t, err, "expected no error, got %+v", err)
|
||||
|
||||
perms := ret.([]types.Permission)
|
||||
|
||||
Assert(t, len(perms) > 0, "expected len(rules) > 0, got %v", len(perms))
|
||||
}
|
||||
|
||||
// Update rules for test role.
|
||||
{
|
||||
list := []internalRules.Rule{
|
||||
internalRules.Rule{Resource: "messaging:channel:*", Operation: "message.update.all", Value: internalRules.Allow},
|
||||
internalRules.Rule{Resource: "messaging:channel:1", Operation: "message.update.all", Value: internalRules.Deny},
|
||||
internalRules.Rule{Resource: "messaging:channel:2", Operation: "message.update.all"},
|
||||
internalRules.Rule{Resource: "system", Operation: "organisation.create", Value: internalRules.Allow},
|
||||
internalRules.Rule{Resource: "system:organisation:*", Operation: "access", Value: internalRules.Allow},
|
||||
internalRules.Rule{Resource: "messaging:channel", Operation: "message.update.all", Value: internalRules.Allow},
|
||||
}
|
||||
_, err := rulesSvc.Update(role.ID, list)
|
||||
NoError(t, err, "expected no error, got %+v", err)
|
||||
}
|
||||
|
||||
// Update with invalid roles
|
||||
{
|
||||
list := []internalRules.Rule{
|
||||
internalRules.Rule{Resource: "nosystem:channel:*", Operation: "message.update.all", Value: internalRules.Allow},
|
||||
}
|
||||
_, err := rulesSvc.Update(role.ID, list)
|
||||
Error(t, err, "expected error")
|
||||
|
||||
list = []internalRules.Rule{
|
||||
internalRules.Rule{Resource: "messaging:noresource:1", Operation: "message.update.all", Value: internalRules.Deny},
|
||||
}
|
||||
_, err = rulesSvc.Update(role.ID, list)
|
||||
Error(t, err, "expected error")
|
||||
|
||||
list = []internalRules.Rule{
|
||||
internalRules.Rule{Resource: "messaging:channel:", Operation: "message.update.all"},
|
||||
}
|
||||
_, err = rulesSvc.Update(role.ID, list)
|
||||
Error(t, err, "expected error")
|
||||
|
||||
list = []internalRules.Rule{
|
||||
internalRules.Rule{Resource: "system:organisation:*", Operation: "invalid", Value: internalRules.Allow},
|
||||
}
|
||||
_, err = rulesSvc.Update(role.ID, list)
|
||||
Error(t, err, "expected error")
|
||||
}
|
||||
|
||||
// Read rules for test role.
|
||||
{
|
||||
ret, err := rulesSvc.Read(role.ID)
|
||||
NoError(t, err, "expected no error, got %+v", err)
|
||||
|
||||
rules := ret.([]internalRules.Rule)
|
||||
|
||||
Assert(t, len(rules) == 7, "expected len(rules) == 7, got %v", len(rules))
|
||||
}
|
||||
|
||||
// Delete rules for test role.
|
||||
{
|
||||
_, err := rulesSvc.Delete(role.ID)
|
||||
NoError(t, err, "expected no error, got %+v", err)
|
||||
}
|
||||
|
||||
// Read rules for test role.
|
||||
{
|
||||
ret, err := rulesSvc.Read(role.ID)
|
||||
NoError(t, err, "expected no error, got %+v", err)
|
||||
|
||||
rules := ret.([]internalRules.Rule)
|
||||
|
||||
Assert(t, len(rules) == 0, "expected len(rules) == 0, got %v", len(rules))
|
||||
}
|
||||
|
||||
// List possible permissions with no grants.
|
||||
{
|
||||
ret, err := rulesSvc.List()
|
||||
NoError(t, err, "expected no error, got %+v", err)
|
||||
|
||||
perms := ret.([]types.Permission)
|
||||
|
||||
Assert(t, len(perms) == 0, "expected len(rules) == 0, got %v", len(perms))
|
||||
}
|
||||
return errors.New("Rollback")
|
||||
}), "expected rollback error")
|
||||
}
|
||||
34
system/internal/service/service.go
Normal file
34
system/internal/service/service.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type (
|
||||
db interface {
|
||||
Transaction(callback func() error) error
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
o sync.Once
|
||||
DefaultAuth AuthService
|
||||
DefaultUser UserService
|
||||
DefaultRole RoleService
|
||||
DefaultRules RulesService
|
||||
DefaultOrganisation OrganisationService
|
||||
DefaultApplication ApplicationService
|
||||
DefaultPermissions PermissionsService
|
||||
)
|
||||
|
||||
func init() {
|
||||
o.Do(func() {
|
||||
DefaultRules = Rules()
|
||||
DefaultPermissions = Permissions()
|
||||
DefaultAuth = Auth()
|
||||
DefaultUser = User()
|
||||
DefaultRole = Role()
|
||||
DefaultOrganisation = Organisation()
|
||||
DefaultApplication = Application()
|
||||
})
|
||||
}
|
||||
176
system/internal/service/user.go
Normal file
176
system/internal/service/user.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/system/internal/repository"
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
const (
|
||||
ErrUserInvalidCredentials = serviceError("UserInvalidCredentials")
|
||||
ErrUserLocked = serviceError("UserLocked")
|
||||
|
||||
uuidLength = 36
|
||||
)
|
||||
|
||||
type (
|
||||
user struct {
|
||||
db *factory.DB
|
||||
ctx context.Context
|
||||
|
||||
user repository.UserRepository
|
||||
}
|
||||
|
||||
UserService interface {
|
||||
With(ctx context.Context) UserService
|
||||
|
||||
FindByUsername(username string) (*types.User, error)
|
||||
FindByEmail(email string) (*types.User, error)
|
||||
FindByID(id uint64) (*types.User, error)
|
||||
FindByIDs(id ...uint64) (types.UserSet, error)
|
||||
Find(filter *types.UserFilter) (types.UserSet, error)
|
||||
|
||||
FindOrCreate(*types.User) (*types.User, error)
|
||||
|
||||
Create(input *types.User) (*types.User, error)
|
||||
Update(mod *types.User) (*types.User, error)
|
||||
|
||||
Delete(id uint64) error
|
||||
Suspend(id uint64) error
|
||||
Unsuspend(id uint64) error
|
||||
|
||||
ValidateCredentials(username, password string) (*types.User, error)
|
||||
}
|
||||
)
|
||||
|
||||
func User() UserService {
|
||||
return (&user{}).With(context.Background())
|
||||
}
|
||||
|
||||
func (svc *user) With(ctx context.Context) UserService {
|
||||
db := repository.DB(ctx)
|
||||
|
||||
return &user{
|
||||
db: db,
|
||||
ctx: ctx,
|
||||
user: repository.User(ctx, db),
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *user) Delete(id uint64) error {
|
||||
return svc.user.DeleteByID(id)
|
||||
}
|
||||
|
||||
func (svc *user) Suspend(id uint64) error {
|
||||
return svc.user.SuspendByID(id)
|
||||
}
|
||||
|
||||
func (svc *user) Unsuspend(id uint64) error {
|
||||
return svc.user.UnsuspendByID(id)
|
||||
}
|
||||
|
||||
func (svc *user) ValidateCredentials(username, password string) (*types.User, error) {
|
||||
user, err := svc.user.FindByUsername(username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.ValidatePassword(password) {
|
||||
return nil, ErrUserInvalidCredentials
|
||||
}
|
||||
|
||||
if !svc.canLogin(user) {
|
||||
return nil, ErrUserLocked
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (svc *user) FindByID(id uint64) (*types.User, error) {
|
||||
return svc.user.FindByID(id)
|
||||
}
|
||||
|
||||
func (svc *user) FindByIDs(ids ...uint64) (types.UserSet, error) {
|
||||
return svc.user.FindByIDs(ids...)
|
||||
}
|
||||
|
||||
func (svc *user) FindByEmail(email string) (*types.User, error) {
|
||||
return svc.user.FindByEmail(email)
|
||||
}
|
||||
|
||||
func (svc *user) FindByUsername(username string) (*types.User, error) {
|
||||
return svc.user.FindByUsername(username)
|
||||
}
|
||||
|
||||
func (svc *user) Find(filter *types.UserFilter) (types.UserSet, error) {
|
||||
return svc.user.Find(filter)
|
||||
}
|
||||
|
||||
// Finds if user with a specific satosa id exists and returns it otherwise it creates a fresh one
|
||||
func (svc *user) FindOrCreate(user *types.User) (out *types.User, err error) {
|
||||
return out, svc.db.Transaction(func() error {
|
||||
if len(user.SatosaID) != uuidLength {
|
||||
// @todo uuid format check
|
||||
return errors.Errorf("Invalid UUID value (%v) for SATOSA ID", user.SatosaID)
|
||||
}
|
||||
|
||||
out, err = svc.user.FindBySatosaID(user.SatosaID)
|
||||
|
||||
if err == repository.ErrUserNotFound {
|
||||
out, err = svc.user.Create(user)
|
||||
return err
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// FindBySatosaID error
|
||||
return err
|
||||
}
|
||||
|
||||
// @todo need to be more selective with fields we update...
|
||||
out, err = svc.user.Update(out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (svc *user) Create(input *types.User) (out *types.User, err error) {
|
||||
return out, svc.db.Transaction(func() error {
|
||||
// Encrypt user password
|
||||
if out, err = svc.user.Create(input); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (svc *user) Update(mod *types.User) (u *types.User, err error) {
|
||||
return u, svc.db.Transaction(func() (err error) {
|
||||
if u, err = svc.user.FindByID(mod.ID); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Assign changed values
|
||||
u.Email = mod.Email
|
||||
u.Username = mod.Username
|
||||
u.Name = mod.Name
|
||||
u.Handle = mod.Handle
|
||||
u.Kind = mod.Kind
|
||||
|
||||
if u, err = svc.user.Update(u); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (svc *user) canLogin(u *types.User) bool {
|
||||
return u != nil && u.ID > 0 && u.SuspendedAt == nil && u.DeletedAt == nil
|
||||
}
|
||||
204
system/internal/service/user_mock_test.go
Normal file
204
system/internal/service/user_mock_test.go
Normal file
@@ -0,0 +1,204 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: system/internal/service/user.go
|
||||
|
||||
// Package service is a generated GoMock package.
|
||||
package service
|
||||
|
||||
import (
|
||||
context "context"
|
||||
types "github.com/crusttech/crust/system/types"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockUserService is a mock of UserService interface
|
||||
type MockUserService struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockUserServiceMockRecorder
|
||||
}
|
||||
|
||||
// MockUserServiceMockRecorder is the mock recorder for MockUserService
|
||||
type MockUserServiceMockRecorder struct {
|
||||
mock *MockUserService
|
||||
}
|
||||
|
||||
// NewMockUserService creates a new mock instance
|
||||
func NewMockUserService(ctrl *gomock.Controller) *MockUserService {
|
||||
mock := &MockUserService{ctrl: ctrl}
|
||||
mock.recorder = &MockUserServiceMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockUserService) EXPECT() *MockUserServiceMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// With mocks base method
|
||||
func (m *MockUserService) With(ctx context.Context) UserService {
|
||||
ret := m.ctrl.Call(m, "With", ctx)
|
||||
ret0, _ := ret[0].(UserService)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// With indicates an expected call of With
|
||||
func (mr *MockUserServiceMockRecorder) With(ctx interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "With", reflect.TypeOf((*MockUserService)(nil).With), ctx)
|
||||
}
|
||||
|
||||
// FindByUsername mocks base method
|
||||
func (m *MockUserService) FindByUsername(username string) (*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "FindByUsername", username)
|
||||
ret0, _ := ret[0].(*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByUsername indicates an expected call of FindByUsername
|
||||
func (mr *MockUserServiceMockRecorder) FindByUsername(username interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByUsername", reflect.TypeOf((*MockUserService)(nil).FindByUsername), username)
|
||||
}
|
||||
|
||||
// FindByEmail mocks base method
|
||||
func (m *MockUserService) FindByEmail(email string) (*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "FindByEmail", email)
|
||||
ret0, _ := ret[0].(*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByEmail indicates an expected call of FindByEmail
|
||||
func (mr *MockUserServiceMockRecorder) FindByEmail(email interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByEmail", reflect.TypeOf((*MockUserService)(nil).FindByEmail), email)
|
||||
}
|
||||
|
||||
// FindByID mocks base method
|
||||
func (m *MockUserService) FindByID(id uint64) (*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "FindByID", id)
|
||||
ret0, _ := ret[0].(*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByID indicates an expected call of FindByID
|
||||
func (mr *MockUserServiceMockRecorder) FindByID(id interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByID", reflect.TypeOf((*MockUserService)(nil).FindByID), id)
|
||||
}
|
||||
|
||||
// FindByIDs mocks base method
|
||||
func (m *MockUserService) FindByIDs(id ...uint64) (types.UserSet, error) {
|
||||
varargs := []interface{}{}
|
||||
for _, a := range id {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "FindByIDs", varargs...)
|
||||
ret0, _ := ret[0].(types.UserSet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByIDs indicates an expected call of FindByIDs
|
||||
func (mr *MockUserServiceMockRecorder) FindByIDs(id ...interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByIDs", reflect.TypeOf((*MockUserService)(nil).FindByIDs), id...)
|
||||
}
|
||||
|
||||
// Find mocks base method
|
||||
func (m *MockUserService) Find(filter *types.UserFilter) (types.UserSet, error) {
|
||||
ret := m.ctrl.Call(m, "Find", filter)
|
||||
ret0, _ := ret[0].(types.UserSet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Find indicates an expected call of Find
|
||||
func (mr *MockUserServiceMockRecorder) Find(filter interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Find", reflect.TypeOf((*MockUserService)(nil).Find), filter)
|
||||
}
|
||||
|
||||
// FindOrCreate mocks base method
|
||||
func (m *MockUserService) FindOrCreate(arg0 *types.User) (*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "FindOrCreate", arg0)
|
||||
ret0, _ := ret[0].(*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindOrCreate indicates an expected call of FindOrCreate
|
||||
func (mr *MockUserServiceMockRecorder) FindOrCreate(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOrCreate", reflect.TypeOf((*MockUserService)(nil).FindOrCreate), arg0)
|
||||
}
|
||||
|
||||
// Create mocks base method
|
||||
func (m *MockUserService) Create(input *types.User) (*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "Create", input)
|
||||
ret0, _ := ret[0].(*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Create indicates an expected call of Create
|
||||
func (mr *MockUserServiceMockRecorder) Create(input interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockUserService)(nil).Create), input)
|
||||
}
|
||||
|
||||
// Update mocks base method
|
||||
func (m *MockUserService) Update(mod *types.User) (*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "Update", mod)
|
||||
ret0, _ := ret[0].(*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Update indicates an expected call of Update
|
||||
func (mr *MockUserServiceMockRecorder) Update(mod interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockUserService)(nil).Update), mod)
|
||||
}
|
||||
|
||||
// Delete mocks base method
|
||||
func (m *MockUserService) Delete(id uint64) error {
|
||||
ret := m.ctrl.Call(m, "Delete", id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete
|
||||
func (mr *MockUserServiceMockRecorder) Delete(id interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockUserService)(nil).Delete), id)
|
||||
}
|
||||
|
||||
// Suspend mocks base method
|
||||
func (m *MockUserService) Suspend(id uint64) error {
|
||||
ret := m.ctrl.Call(m, "Suspend", id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Suspend indicates an expected call of Suspend
|
||||
func (mr *MockUserServiceMockRecorder) Suspend(id interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Suspend", reflect.TypeOf((*MockUserService)(nil).Suspend), id)
|
||||
}
|
||||
|
||||
// Unsuspend mocks base method
|
||||
func (m *MockUserService) Unsuspend(id uint64) error {
|
||||
ret := m.ctrl.Call(m, "Unsuspend", id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Unsuspend indicates an expected call of Unsuspend
|
||||
func (mr *MockUserServiceMockRecorder) Unsuspend(id interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unsuspend", reflect.TypeOf((*MockUserService)(nil).Unsuspend), id)
|
||||
}
|
||||
|
||||
// ValidateCredentials mocks base method
|
||||
func (m *MockUserService) ValidateCredentials(username, password string) (*types.User, error) {
|
||||
ret := m.ctrl.Call(m, "ValidateCredentials", username, password)
|
||||
ret0, _ := ret[0].(*types.User)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ValidateCredentials indicates an expected call of ValidateCredentials
|
||||
func (mr *MockUserServiceMockRecorder) ValidateCredentials(username, password interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateCredentials", reflect.TypeOf((*MockUserService)(nil).ValidateCredentials), username, password)
|
||||
}
|
||||
31
system/internal/service/user_test.go
Normal file
31
system/internal/service/user_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package service
|
||||
|
||||
// func TestUser(t *testing.T) {
|
||||
// mockCtrl := gomock.NewController(t)
|
||||
// defer mockCtrl.Finish()
|
||||
//
|
||||
// usr := &types.Owner{ID: factory.Sonyflake.NextID()}
|
||||
//
|
||||
// usrRpoMock := NewMockRepository(mockCtrl)
|
||||
// usrRpoMock.EXPECT().WithCtx(gomock.Any()).AnyTimes().Return(usrRpoMock)
|
||||
// usrRpoMock.EXPECT().
|
||||
// FindUserByID(usr.ID).
|
||||
// Times(1).
|
||||
// Return(usr, nil)
|
||||
//
|
||||
// svc := Owner()
|
||||
// svc.rpo = usrRpoMock
|
||||
//
|
||||
// found, err := svc.FindByID(context.Background(), usr.ID)
|
||||
// if err != nil {
|
||||
// t.Fatal("Did not expect an error")
|
||||
// }
|
||||
//
|
||||
// if found == nil {
|
||||
// t.Fatal("Expecting an user to be found")
|
||||
// }
|
||||
//
|
||||
// if found.ID != usr.ID {
|
||||
// t.Fatal("Expecting found user to have the same ID as the find param")
|
||||
// }
|
||||
// }
|
||||
126
system/internal/service/validation.go
Normal file
126
system/internal/service/validation.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
permissionList = map[string]map[string]bool{
|
||||
"system": map[string]bool{
|
||||
"access": true,
|
||||
"grant": true,
|
||||
"organisation.create": true,
|
||||
"role.create": true,
|
||||
"application.create": true,
|
||||
},
|
||||
"system:organisation": map[string]bool{
|
||||
"access": true,
|
||||
},
|
||||
"system:role": map[string]bool{
|
||||
"read": true,
|
||||
"update": true,
|
||||
"delete": true,
|
||||
"members.manage": true,
|
||||
},
|
||||
"system:application": map[string]bool{
|
||||
"read": true,
|
||||
"update": true,
|
||||
"delete": true,
|
||||
},
|
||||
"messaging": map[string]bool{
|
||||
"access": true,
|
||||
"grant": true,
|
||||
"channel.public.create": true,
|
||||
"channel.private.create": true,
|
||||
"channel.group.create": true,
|
||||
},
|
||||
"messaging:channel": map[string]bool{
|
||||
"update": true,
|
||||
"read": true,
|
||||
"join": true,
|
||||
"leave": true,
|
||||
"delete": true,
|
||||
"undelete": true,
|
||||
"archive": true,
|
||||
"unarchive": true,
|
||||
"members.manage": true,
|
||||
"webhooks.manage": true,
|
||||
"attachments.manage": true,
|
||||
"message.send": true,
|
||||
"message.reply": true,
|
||||
"message.embed": true,
|
||||
"message.attach": true,
|
||||
"message.update.own": true,
|
||||
"message.update.all": true,
|
||||
"message.delete.own": true,
|
||||
"message.delete.all": true,
|
||||
"message.react": true,
|
||||
},
|
||||
"compose": map[string]bool{
|
||||
"access": true,
|
||||
"grant": true,
|
||||
"namespace.create": true,
|
||||
},
|
||||
"compose:namespace": map[string]bool{
|
||||
"read": true,
|
||||
"update": true,
|
||||
"delete": true,
|
||||
"module.create": true,
|
||||
"chart.create": true,
|
||||
"trigger.create": true,
|
||||
"page.create": true,
|
||||
},
|
||||
"compose:module": map[string]bool{
|
||||
"read": true,
|
||||
"update": true,
|
||||
"delete": true,
|
||||
"record.create": true,
|
||||
"record.read": true,
|
||||
"record.update": true,
|
||||
"record.delete": true,
|
||||
},
|
||||
"compose:chart": map[string]bool{
|
||||
"read": true,
|
||||
"update": true,
|
||||
"delete": true,
|
||||
},
|
||||
"compose:trigger": map[string]bool{
|
||||
"read": true,
|
||||
"update": true,
|
||||
"delete": true,
|
||||
},
|
||||
"compose:page": map[string]bool{
|
||||
"read": true,
|
||||
"update": true,
|
||||
"delete": true,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func validatePermission(resource string, operation string) error {
|
||||
resourceParts := strings.Split(resource, delimiter)
|
||||
if len(resourceParts) < 1 {
|
||||
return errors.Errorf("Invalid resource format, expected >= 1, got %d", len(resourceParts))
|
||||
}
|
||||
|
||||
resourceName := resourceParts[0]
|
||||
if len(resourceParts) > 1 {
|
||||
resourceName = resourceParts[0] + delimiter + resourceParts[1]
|
||||
}
|
||||
|
||||
if service, ok := permissionList[resourceName]; ok {
|
||||
if op := service[operation]; op {
|
||||
if len(resourceParts) == 3 {
|
||||
if val := resourceParts[2]; val != "" {
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf("Invalid resource format, missing resource ID")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf("Unknown operation: '%s'", operation)
|
||||
}
|
||||
return errors.Errorf("Unknown resource name: '%s'", resourceName)
|
||||
}
|
||||
15
system/internal/service/validation_test.go
Normal file
15
system/internal/service/validation_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/crusttech/crust/internal/test"
|
||||
)
|
||||
|
||||
func TestPermissionsValidation(t *testing.T) {
|
||||
test.Error(t, validatePermission("bogus", "bogus"), "expected error")
|
||||
test.Error(t, validatePermission("bogus", "bogus"), "expected error")
|
||||
test.Error(t, validatePermission("messaging:channel", "bogus"), "expected error")
|
||||
test.Error(t, validatePermission("messaging:channel:", "message.send"), "expected error")
|
||||
test.NoError(t, validatePermission("messaging:channel:1", "message.send"), "expected valid response")
|
||||
}
|
||||
Reference in New Issue
Block a user