3
0

Add permission checking for all CRM elements

This commit is contained in:
Denis Arh
2019-03-12 00:19:21 +01:00
parent c6bb0afc9f
commit 61b94f620e
7 changed files with 284 additions and 104 deletions

View File

@@ -34,7 +34,12 @@ type (
ctx context.Context
store store.Store
usr systemService.UserService
prmSvc PermissionsService
pageSvc PageService
moduleSvc ModuleService
recordSvc RecordService
usr systemService.UserService
attachment repository.AttachmentRepository
}
@@ -53,8 +58,12 @@ type (
func Attachment(store store.Store) AttachmentService {
return (&attachment{
store: store,
usr: systemService.DefaultUser,
store: store,
prmSvc: DefaultPermissions,
pageSvc: DefaultPage,
moduleSvc: DefaultModule,
recordSvc: DefaultRecord,
usr: systemService.DefaultUser,
}).With(context.Background())
}
@@ -64,20 +73,40 @@ func (svc *attachment) With(ctx context.Context) AttachmentService {
db: db,
ctx: ctx,
store: svc.store,
usr: svc.usr.With(ctx),
prmSvc: svc.prmSvc.With(ctx),
pageSvc: svc.pageSvc.With(ctx),
moduleSvc: svc.moduleSvc.With(ctx),
recordSvc: svc.recordSvc.With(ctx),
usr: svc.usr.With(ctx),
store: svc.store,
attachment: repository.Attachment(ctx, db),
}
}
func (svc *attachment) FindByID(id uint64) (*types.Attachment, error) {
// @todo [SECURITY] check if record/page can be accessed
return svc.attachment.FindByID(id)
}
func (svc *attachment) Find(filter types.AttachmentFilter) (types.AttachmentSet, types.AttachmentFilter, error) {
// @todo [SECURITY] enforce filter combination (page / module+record+field) & check access
if filter.PageID > 0 {
if _, err := svc.pageSvc.FindByID(filter.PageID); err != nil {
return nil, filter, err
}
}
if filter.ModuleID > 0 {
if _, err := svc.moduleSvc.FindByID(filter.ModuleID); err != nil {
return nil, filter, err
}
}
if filter.RecordID > 0 {
if _, err := svc.recordSvc.FindByID(filter.RecordID); err != nil {
return nil, filter, err
}
}
return svc.attachment.Find(filter)
}
@@ -100,8 +129,11 @@ func (svc *attachment) OpenPreview(att *types.Attachment) (io.ReadSeeker, error)
func (svc *attachment) CreatePageAttachment(name string, size int64, fh io.ReadSeeker, pageID uint64) (*types.Attachment, error) {
var currentUserID uint64 = auth.GetIdentityFromContext(svc.ctx).Identity()
// @todo verify if current user can access this page
// @todo verify if current user can upload to this page
if p, err := svc.pageSvc.FindByID(pageID); err != nil {
return nil, err
} else if !svc.prmSvc.CanUpdatePage(p) {
return nil, errors.New("not allowed to add attachments to this page")
}
att := &types.Attachment{
ID: factory.Sonyflake.NextID(),
@@ -115,8 +147,13 @@ func (svc *attachment) CreatePageAttachment(name string, size int64, fh io.ReadS
func (svc *attachment) CreateRecordAttachment(name string, size int64, fh io.ReadSeeker, moduleID, recordID uint64, fieldName string) (*types.Attachment, error) {
var currentUserID uint64 = auth.GetIdentityFromContext(svc.ctx).Identity()
// @todo verify if current user can access this record
// @todo verify if current user can upload to this record
if _, err := svc.moduleSvc.FindByID(moduleID); err != nil {
return nil, err
} else if r, err := svc.recordSvc.FindByID(recordID); err != nil {
return nil, err
} else if !svc.prmSvc.CanUpdateRecord(r) {
return nil, errors.New("not allowed to add attachments to this record")
}
att := &types.Attachment{
ID: factory.Sonyflake.NextID(),

View File

@@ -15,6 +15,8 @@ type (
db *factory.DB
ctx context.Context
prmSvc PermissionsService
chartRepo repository.ChartRepository
}
@@ -31,27 +33,48 @@ type (
)
func Chart() ChartService {
return (&chart{}).With(context.Background())
return (&chart{
prmSvc: DefaultPermissions,
}).With(context.Background())
}
func (svc *chart) With(ctx context.Context) ChartService {
db := repository.DB(ctx)
return &chart{
db: db,
ctx: ctx,
db: db,
ctx: ctx,
prmSvc: svc.prmSvc.With(ctx),
chartRepo: repository.Chart(ctx, db),
}
}
func (svc *chart) FindByID(chartID uint64) (*types.Chart, error) {
return svc.chartRepo.FindByID(chartID)
func (svc *chart) FindByID(chartID uint64) (c *types.Chart, err error) {
if c, err = svc.chartRepo.FindByID(chartID); err != nil {
return
} else if !svc.prmSvc.CanReadChart(c) {
return nil, errors.New("not allowed to access this chart")
}
return
}
func (svc *chart) Find() (types.ChartSet, error) {
return svc.chartRepo.Find()
func (svc *chart) Find() (cc types.ChartSet, err error) {
if cc, err = svc.chartRepo.Find(); err != nil {
return nil, err
} else {
return cc.Filter(func(m *types.Chart) (bool, error) {
return svc.prmSvc.CanReadChart(m), nil
})
}
}
func (svc *chart) Create(mod *types.Chart) (c *types.Chart, err error) {
if !svc.prmSvc.CanCreateChart() {
return nil, errors.New("not allowed to create this chart")
}
return c, svc.db.Transaction(func() error {
c, err = svc.chartRepo.Create(mod)
return err
@@ -65,6 +88,10 @@ func (svc *chart) Update(mod *types.Chart) (c *types.Chart, err error) {
} else if c, err = svc.chartRepo.FindByID(mod.ID); err != nil {
return errors.Wrap(err, "Error while loading chart for update")
} else {
if !svc.prmSvc.CanUpdateChart(c) {
return errors.New("not allowed to update this chart")
}
mod.CreatedAt = c.CreatedAt
}
@@ -84,6 +111,10 @@ func (svc *chart) Update(mod *types.Chart) (c *types.Chart, err error) {
})
}
func (svc *chart) DeleteByID(chartID uint64) error {
return svc.chartRepo.DeleteByID(chartID)
func (svc *chart) DeleteByID(ID uint64) error {
if !svc.prmSvc.CanDeleteChartByID(ID) {
return errors.New("not allowed to delete this chart")
}
return svc.chartRepo.DeleteByID(ID)
}

View File

@@ -15,6 +15,8 @@ type (
db *factory.DB
ctx context.Context
prmSvc PermissionsService
moduleRepo repository.ModuleRepository
pageRepo repository.PageRepository
}
@@ -23,7 +25,7 @@ type (
With(ctx context.Context) ModuleService
FindByID(moduleID uint64) (*types.Module, error)
Find() ([]*types.Module, error)
Find() (types.ModuleSet, error)
Create(module *types.Module) (*types.Module, error)
Update(module *types.Module) (*types.Module, error)
@@ -32,46 +34,67 @@ type (
)
func Module() ModuleService {
return (&module{}).With(context.Background())
return (&module{
prmSvc: DefaultPermissions,
}).With(context.Background())
}
func (s *module) With(ctx context.Context) ModuleService {
func (svc *module) With(ctx context.Context) ModuleService {
db := repository.DB(ctx)
return &module{
db: db,
ctx: ctx,
db: db,
ctx: ctx,
prmSvc: svc.prmSvc.With(ctx),
moduleRepo: repository.Module(ctx, db),
pageRepo: repository.Page(ctx, db),
}
}
func (s *module) FindByID(id uint64) (*types.Module, error) {
mod, err := s.moduleRepo.FindByID(id)
if err != nil {
return nil, err
func (svc *module) FindByID(id uint64) (m *types.Module, err error) {
if m, err = svc.moduleRepo.FindByID(id); err != nil {
return
} else if !svc.prmSvc.CanReadModule(m) {
return nil, errors.New("not allowed to access this module")
}
return mod, err
return
}
func (s *module) Find() ([]*types.Module, error) {
return s.moduleRepo.Find()
func (svc *module) Find() (mm types.ModuleSet, err error) {
if mm, err = svc.moduleRepo.Find(); err != nil {
return nil, err
} else {
return mm.Filter(func(m *types.Module) (bool, error) {
return svc.prmSvc.CanReadModule(m), nil
})
}
}
func (s *module) Create(mod *types.Module) (*types.Module, error) {
func (svc *module) Create(mod *types.Module) (*types.Module, error) {
if !svc.prmSvc.CanCreateModule() {
return nil, errors.New("not allowed to create this module")
}
if len(mod.Fields) == 0 {
return nil, errors.New("Error creating module: no fields")
}
return s.moduleRepo.Create(mod)
return svc.moduleRepo.Create(mod)
}
func (s *module) Update(module *types.Module) (m *types.Module, err error) {
func (svc *module) Update(module *types.Module) (m *types.Module, err error) {
validate := func() error {
if module.ID == 0 {
return errors.New("Error updating module: invalid ID")
} else if m, err = s.moduleRepo.FindByID(module.ID); err != nil {
} else if m, err = svc.moduleRepo.FindByID(module.ID); err != nil {
return errors.Wrap(err, "Error while loading module for update")
} else {
if !svc.prmSvc.CanUpdateModule(m) {
return errors.New("not allowed to update this module")
}
module.CreatedAt = m.CreatedAt
}
@@ -86,12 +109,16 @@ func (s *module) Update(module *types.Module) (m *types.Module, err error) {
return nil, err
}
return m, s.db.Transaction(func() (err error) {
m, err = s.moduleRepo.Update(module)
return m, svc.db.Transaction(func() (err error) {
m, err = svc.moduleRepo.Update(module)
return
})
}
func (s *module) DeleteByID(id uint64) error {
return s.moduleRepo.DeleteByID(id)
func (svc *module) DeleteByID(ID uint64) error {
if !svc.prmSvc.CanDeleteModuleByID(ID) {
return errors.New("not allowed to delete this module")
}
return svc.moduleRepo.DeleteByID(ID)
}

View File

@@ -2,8 +2,8 @@ package service
import (
"context"
"errors"
"github.com/pkg/errors"
"github.com/titpetric/factory"
"github.com/crusttech/crust/crm/repository"
@@ -15,6 +15,8 @@ type (
db *factory.DB
ctx context.Context
prmSvc PermissionsService
pageRepo repository.PageRepository
moduleRepo repository.ModuleRepository
}
@@ -38,45 +40,59 @@ type (
)
func Page() PageService {
return (&page{}).With(context.Background())
return (&page{
prmSvc: DefaultPermissions,
}).With(context.Background())
}
func (s *page) With(ctx context.Context) PageService {
func (svc *page) With(ctx context.Context) PageService {
db := repository.DB(ctx)
return &page{
db: db,
ctx: ctx,
db: db,
ctx: ctx,
prmSvc: svc.prmSvc.With(ctx),
pageRepo: repository.Page(ctx, db),
moduleRepo: repository.Module(ctx, db),
}
}
func (s *page) FindByID(id uint64) (*types.Page, error) {
return s.pageRepo.FindByID(id)
func (svc *page) FindByID(id uint64) (p *types.Page, err error) {
return svc.checkPermissions(svc.pageRepo.FindByID(id))
}
func (s *page) FindByModuleID(moduleID uint64) (*types.Page, error) {
return s.pageRepo.FindByModuleID(moduleID)
func (svc *page) FindByModuleID(moduleID uint64) (p *types.Page, err error) {
return svc.checkPermissions(svc.pageRepo.FindByModuleID(moduleID))
}
func (s *page) FindBySelfID(selfID uint64) (pages types.PageSet, err error) {
return s.pageRepo.FindBySelfID(selfID)
func (svc *page) checkPermissions(p *types.Page, err error) (*types.Page, error) {
if err != nil {
return nil, err
} else if !svc.prmSvc.CanReadPage(p) {
return nil, errors.New("not allowed to access this page")
}
return p, err
}
func (s *page) Find() (pages types.PageSet, err error) {
return s.pageRepo.Find()
func (svc *page) FindBySelfID(selfID uint64) (pp types.PageSet, err error) {
return svc.filterPageSet(svc.pageRepo.FindBySelfID(selfID))
}
func (s *page) Tree() (pages types.PageSet, err error) {
func (svc *page) Find() (pages types.PageSet, err error) {
return svc.filterPageSet(svc.pageRepo.Find())
}
func (svc *page) Tree() (pages types.PageSet, err error) {
var tree types.PageSet
return tree, s.db.Transaction(func() (err error) {
if pages, err = s.pageRepo.Find(); err != nil {
return tree, svc.db.Transaction(func() (err error) {
if pages, err = svc.filterPageSet(svc.pageRepo.Find()); err != nil {
return
}
// No preloading - we do not need (or should have) any modules
// associated with us
_ = pages.Walk(func(p *types.Page) error {
if p.SelfID == 0 {
@@ -100,19 +116,32 @@ func (s *page) Tree() (pages types.PageSet, err error) {
})
}
func (s *page) FindRecordPages() (pages types.PageSet, err error) {
return s.pageRepo.FindRecordPages()
func (svc *page) FindRecordPages() (pages types.PageSet, err error) {
return svc.pageRepo.FindRecordPages()
}
func (s *page) Reorder(selfID uint64, pageIDs []uint64) error {
return s.pageRepo.Reorder(selfID, pageIDs)
func (svc *page) filterPageSet(pp types.PageSet, err error) (types.PageSet, error) {
if err != nil {
return nil, err
}
return pp.Filter(func(m *types.Page) (bool, error) {
return svc.prmSvc.CanReadPage(m), nil
})
}
func (s *page) Create(page *types.Page) (p *types.Page, err error) {
func (svc *page) Reorder(selfID uint64, pageIDs []uint64) error {
return svc.pageRepo.Reorder(selfID, pageIDs)
}
func (svc *page) Create(page *types.Page) (p *types.Page, err error) {
validate := func() error {
if !svc.prmSvc.CanCreatePage() {
return errors.New("not allowed to create this module")
}
if page.ModuleID > 0 {
// @todo check if module exists!
if p, err = s.pageRepo.FindByModuleID(page.ModuleID); err != nil {
if p, err = svc.pageRepo.FindByModuleID(page.ModuleID); err != nil {
return err
} else if p.ID > 0 {
return errors.New("Page for module already exists")
@@ -123,20 +152,26 @@ func (s *page) Create(page *types.Page) (p *types.Page, err error) {
if err := validate(); err != nil {
return nil, err
}
return p, s.db.Transaction(func() (err error) {
p, err = s.pageRepo.Create(page)
return p, svc.db.Transaction(func() (err error) {
p, err = svc.pageRepo.Create(page)
return
})
}
func (s *page) Update(page *types.Page) (p *types.Page, err error) {
func (svc *page) Update(page *types.Page) (p *types.Page, err error) {
validate := func() error {
if page.ID == 0 {
return errors.New("Error when savig page, invalid ID")
} else if p, err = svc.pageRepo.FindByID(page.ID); err != nil {
return errors.Wrap(err, "Error while loading module for update")
} else {
if !svc.prmSvc.CanUpdatePage(p) {
return errors.New("not allowed to update this pahe")
}
}
if page.ModuleID > 0 {
// @todo check if module exists!
if p, err = s.pageRepo.FindByModuleID(page.ModuleID); err != nil {
if p, err = svc.pageRepo.FindByModuleID(page.ModuleID); err != nil {
return err
} else if p.ID > 0 && page.ID != p.ID {
return errors.New("Page for module already exists")
@@ -147,12 +182,16 @@ func (s *page) Update(page *types.Page) (p *types.Page, err error) {
if err := validate(); err != nil {
return nil, err
}
return p, s.db.Transaction(func() (err error) {
p, err = s.pageRepo.Update(page)
return p, svc.db.Transaction(func() (err error) {
p, err = svc.pageRepo.Update(page)
return
})
}
func (s *page) DeleteByID(id uint64) error {
return s.pageRepo.DeleteByID(id)
func (svc *page) DeleteByID(ID uint64) error {
if !svc.prmSvc.CanDeletePageByID(ID) {
return errors.New("not allowed to delete this page")
}
return svc.pageRepo.DeleteByID(ID)
}

View File

@@ -20,10 +20,11 @@ type (
db *factory.DB
ctx context.Context
prmSvc PermissionsService
userSvc systemService.UserService
repository repository.RecordRepository
moduleRepo repository.ModuleRepository
userSvc systemService.UserService
}
RecordService interface {
@@ -45,6 +46,7 @@ type (
func Record() RecordService {
return (&record{
prmSvc: DefaultPermissions,
userSvc: systemService.DefaultUser,
}).With(context.Background())
}
@@ -52,11 +54,14 @@ func Record() RecordService {
func (svc *record) With(ctx context.Context) RecordService {
db := repository.DB(ctx)
return &record{
db: db,
ctx: ctx,
db: db,
ctx: ctx,
prmSvc: svc.prmSvc.With(ctx),
userSvc: svc.userSvc.With(ctx),
repository: repository.Record(ctx, db),
moduleRepo: repository.Module(ctx, db),
userSvc: svc.userSvc.With(ctx),
}
}
@@ -64,6 +69,8 @@ func (svc *record) FindByID(recordID uint64) (r *types.Record, err error) {
err = svc.db.Transaction(func() (err error) {
if r, err = svc.repository.FindByID(recordID); err != nil {
return
} else if !svc.prmSvc.CanReadRecord(r) {
return errors.New("not allowed to access this record")
}
if err = svc.preloadValues(r); err != nil {
@@ -82,6 +89,8 @@ func (svc *record) Report(moduleID uint64, metrics, dimensions, filter string) (
err = svc.db.Transaction(func() (err error) {
if module, err = svc.moduleRepo.FindByID(moduleID); err != nil {
return
} else if !svc.prmSvc.CanReadRecord(module) {
return errors.New("not allowed to access this record")
}
out, err = svc.repository.Report(module, metrics, dimensions, filter)
@@ -97,6 +106,8 @@ func (svc *record) Find(moduleID uint64, filter, sort string, page, perPage int)
err = svc.db.Transaction(func() (err error) {
if module, err = svc.moduleRepo.FindByID(moduleID); err != nil {
return
} else if !svc.prmSvc.CanReadRecord(module) {
return errors.New("not allowed to access this record")
}
if rsp, err = svc.repository.Find(module, filter, sort, page, perPage); err != nil {
@@ -120,6 +131,8 @@ func (svc *record) Create(in *types.Record) (record *types.Record, err error) {
err = svc.db.Transaction(func() (err error) {
if module, err = svc.moduleRepo.FindByID(in.ModuleID); err != nil {
return
} else if !svc.prmSvc.CanCreateRecord(module) {
return errors.New("not allowed to create records for this module")
}
if err = svc.sanitizeValues(module, in.Values); err != nil {
@@ -157,6 +170,8 @@ func (svc *record) Update(updated *types.Record) (record *types.Record, err erro
if record, err = svc.repository.FindByID(updated.ID); err != nil {
return errors.Wrap(err, "nonexistent record")
} else if !svc.prmSvc.CanUpdateRecord(record) {
return errors.New("not allowed to update this record")
}
if module, err = svc.moduleRepo.FindByID(updated.ModuleID); err != nil {

View File

@@ -32,13 +32,13 @@ func Init() {
log.Fatalf("Failed to initialize store: %v", err)
}
DefaultPermissions = Permissions()
DefaultRecord = Record()
DefaultModule = Module()
DefaultTrigger = Trigger()
DefaultPage = Page()
DefaultChart = Chart()
DefaultNotification = Notification()
DefaultPermissions = Permissions()
DefaultAttachment = Attachment(fs)
})
}

View File

@@ -2,8 +2,8 @@ package service
import (
"context"
"errors"
"github.com/pkg/errors"
"github.com/titpetric/factory"
"github.com/crusttech/crust/crm/repository"
@@ -15,6 +15,8 @@ type (
db *factory.DB
ctx context.Context
prmSvc PermissionsService
triggerRepo repository.TriggerRepository
moduleRepo repository.ModuleRepository
}
@@ -32,56 +34,85 @@ type (
)
func Trigger() TriggerService {
return (&trigger{}).With(context.Background())
return (&trigger{
prmSvc: DefaultPermissions}).With(context.Background())
}
func (s *trigger) With(ctx context.Context) TriggerService {
func (svc *trigger) With(ctx context.Context) TriggerService {
db := repository.DB(ctx)
return &trigger{
db: db,
ctx: ctx,
db: db,
ctx: ctx,
prmSvc: svc.prmSvc.With(ctx),
triggerRepo: repository.Trigger(ctx, db),
moduleRepo: repository.Module(ctx, db),
}
}
func (s *trigger) FindByID(id uint64) (*types.Trigger, error) {
return s.triggerRepo.FindByID(id)
}
func (s *trigger) Find(filter types.TriggerFilter) (types.TriggerSet, error) {
return s.triggerRepo.Find(filter)
}
func (s *trigger) Create(trigger *types.Trigger) (p *types.Trigger, err error) {
validate := func() error {
return nil
func (svc *trigger) FindByID(id uint64) (t *types.Trigger, err error) {
if t, err = svc.triggerRepo.FindByID(id); err != nil {
return
} else if !svc.prmSvc.CanReadTrigger(t) {
return nil, errors.New("not allowed to access this trigger")
}
if err := validate(); err != nil {
return
}
func (svc *trigger) Find(filter types.TriggerFilter) (tt types.TriggerSet, err error) {
if tt, err = svc.triggerRepo.Find(filter); err != nil {
return nil, err
} else {
return tt.Filter(func(m *types.Trigger) (bool, error) {
return svc.prmSvc.CanReadTrigger(m), nil
})
}
return p, s.db.Transaction(func() (err error) {
p, err = s.triggerRepo.Create(trigger)
}
func (svc *trigger) Create(trigger *types.Trigger) (p *types.Trigger, err error) {
if !svc.prmSvc.CanCreateTrigger() {
return nil, errors.New("not allowed to create this trigger")
}
return p, svc.db.Transaction(func() (err error) {
p, err = svc.triggerRepo.Create(trigger)
return
})
}
func (s *trigger) Update(trigger *types.Trigger) (p *types.Trigger, err error) {
func (svc *trigger) Update(trigger *types.Trigger) (t *types.Trigger, err error) {
validate := func() error {
if trigger.ID == 0 {
return errors.New("could not update trigger, invalid ID")
return errors.New("Error updating trigger: invalid ID")
} else if t, err = svc.triggerRepo.FindByID(trigger.ID); err != nil {
return errors.Wrap(err, "Error while loading trigger for update")
} else {
if !svc.prmSvc.CanUpdateModule(t) {
return errors.New("not allowed to update this trigger")
}
trigger.CreatedAt = t.CreatedAt
}
return nil
}
if err := validate(); err != nil {
return nil, err
}
return p, s.db.Transaction(func() (err error) {
p, err = s.triggerRepo.Update(trigger)
return t, svc.db.Transaction(func() (err error) {
t, err = svc.triggerRepo.Update(trigger)
return
})
}
func (s *trigger) DeleteByID(id uint64) error {
return s.triggerRepo.DeleteByID(id)
func (svc *trigger) DeleteByID(ID uint64) error {
if !svc.prmSvc.CanDeleteTriggerByID(ID) {
return errors.New("not allowed to delete this trigger")
}
return svc.triggerRepo.DeleteByID(ID)
}