From 61b94f620e4a3d895ff793b419e5ace42df2829d Mon Sep 17 00:00:00 2001 From: Denis Arh Date: Tue, 12 Mar 2019 00:19:21 +0100 Subject: [PATCH] Add permission checking for all CRM elements --- crm/service/attachment.go | 59 +++++++++++++++++---- crm/service/chart.go | 49 ++++++++++++++---- crm/service/module.go | 67 ++++++++++++++++-------- crm/service/page.go | 105 ++++++++++++++++++++++++++------------ crm/service/record.go | 25 +++++++-- crm/service/service.go | 2 +- crm/service/trigger.go | 81 ++++++++++++++++++++--------- 7 files changed, 284 insertions(+), 104 deletions(-) diff --git a/crm/service/attachment.go b/crm/service/attachment.go index 2a1522562..5657d6c3d 100644 --- a/crm/service/attachment.go +++ b/crm/service/attachment.go @@ -34,7 +34,12 @@ type ( ctx context.Context store store.Store - usr systemService.UserService + + prmSvc PermissionsService + pageSvc PageService + moduleSvc ModuleService + recordSvc RecordService + usr systemService.UserService attachment repository.AttachmentRepository } @@ -53,8 +58,12 @@ type ( func Attachment(store store.Store) AttachmentService { return (&attachment{ - store: store, - usr: systemService.DefaultUser, + store: store, + prmSvc: DefaultPermissions, + pageSvc: DefaultPage, + moduleSvc: DefaultModule, + recordSvc: DefaultRecord, + usr: systemService.DefaultUser, }).With(context.Background()) } @@ -64,20 +73,40 @@ func (svc *attachment) With(ctx context.Context) AttachmentService { db: db, ctx: ctx, - store: svc.store, - usr: svc.usr.With(ctx), + prmSvc: svc.prmSvc.With(ctx), + pageSvc: svc.pageSvc.With(ctx), + moduleSvc: svc.moduleSvc.With(ctx), + recordSvc: svc.recordSvc.With(ctx), + usr: svc.usr.With(ctx), + store: svc.store, attachment: repository.Attachment(ctx, db), } } func (svc *attachment) FindByID(id uint64) (*types.Attachment, error) { - // @todo [SECURITY] check if record/page can be accessed return svc.attachment.FindByID(id) } func (svc *attachment) Find(filter types.AttachmentFilter) (types.AttachmentSet, types.AttachmentFilter, error) { - // @todo [SECURITY] enforce filter combination (page / module+record+field) & check access + if filter.PageID > 0 { + if _, err := svc.pageSvc.FindByID(filter.PageID); err != nil { + return nil, filter, err + } + } + + if filter.ModuleID > 0 { + if _, err := svc.moduleSvc.FindByID(filter.ModuleID); err != nil { + return nil, filter, err + } + } + + if filter.RecordID > 0 { + if _, err := svc.recordSvc.FindByID(filter.RecordID); err != nil { + return nil, filter, err + } + } + return svc.attachment.Find(filter) } @@ -100,8 +129,11 @@ func (svc *attachment) OpenPreview(att *types.Attachment) (io.ReadSeeker, error) func (svc *attachment) CreatePageAttachment(name string, size int64, fh io.ReadSeeker, pageID uint64) (*types.Attachment, error) { var currentUserID uint64 = auth.GetIdentityFromContext(svc.ctx).Identity() - // @todo verify if current user can access this page - // @todo verify if current user can upload to this page + if p, err := svc.pageSvc.FindByID(pageID); err != nil { + return nil, err + } else if !svc.prmSvc.CanUpdatePage(p) { + return nil, errors.New("not allowed to add attachments to this page") + } att := &types.Attachment{ ID: factory.Sonyflake.NextID(), @@ -115,8 +147,13 @@ func (svc *attachment) CreatePageAttachment(name string, size int64, fh io.ReadS func (svc *attachment) CreateRecordAttachment(name string, size int64, fh io.ReadSeeker, moduleID, recordID uint64, fieldName string) (*types.Attachment, error) { var currentUserID uint64 = auth.GetIdentityFromContext(svc.ctx).Identity() - // @todo verify if current user can access this record - // @todo verify if current user can upload to this record + if _, err := svc.moduleSvc.FindByID(moduleID); err != nil { + return nil, err + } else if r, err := svc.recordSvc.FindByID(recordID); err != nil { + return nil, err + } else if !svc.prmSvc.CanUpdateRecord(r) { + return nil, errors.New("not allowed to add attachments to this record") + } att := &types.Attachment{ ID: factory.Sonyflake.NextID(), diff --git a/crm/service/chart.go b/crm/service/chart.go index 56ebbfafc..76db9553d 100644 --- a/crm/service/chart.go +++ b/crm/service/chart.go @@ -15,6 +15,8 @@ type ( db *factory.DB ctx context.Context + prmSvc PermissionsService + chartRepo repository.ChartRepository } @@ -31,27 +33,48 @@ type ( ) func Chart() ChartService { - return (&chart{}).With(context.Background()) + return (&chart{ + prmSvc: DefaultPermissions, + }).With(context.Background()) } func (svc *chart) With(ctx context.Context) ChartService { db := repository.DB(ctx) return &chart{ - db: db, - ctx: ctx, + db: db, + ctx: ctx, + + prmSvc: svc.prmSvc.With(ctx), + chartRepo: repository.Chart(ctx, db), } } -func (svc *chart) FindByID(chartID uint64) (*types.Chart, error) { - return svc.chartRepo.FindByID(chartID) +func (svc *chart) FindByID(chartID uint64) (c *types.Chart, err error) { + if c, err = svc.chartRepo.FindByID(chartID); err != nil { + return + } else if !svc.prmSvc.CanReadChart(c) { + return nil, errors.New("not allowed to access this chart") + } + + return } -func (svc *chart) Find() (types.ChartSet, error) { - return svc.chartRepo.Find() +func (svc *chart) Find() (cc types.ChartSet, err error) { + if cc, err = svc.chartRepo.Find(); err != nil { + return nil, err + } else { + return cc.Filter(func(m *types.Chart) (bool, error) { + return svc.prmSvc.CanReadChart(m), nil + }) + } } func (svc *chart) Create(mod *types.Chart) (c *types.Chart, err error) { + if !svc.prmSvc.CanCreateChart() { + return nil, errors.New("not allowed to create this chart") + } + return c, svc.db.Transaction(func() error { c, err = svc.chartRepo.Create(mod) return err @@ -65,6 +88,10 @@ func (svc *chart) Update(mod *types.Chart) (c *types.Chart, err error) { } else if c, err = svc.chartRepo.FindByID(mod.ID); err != nil { return errors.Wrap(err, "Error while loading chart for update") } else { + if !svc.prmSvc.CanUpdateChart(c) { + return errors.New("not allowed to update this chart") + } + mod.CreatedAt = c.CreatedAt } @@ -84,6 +111,10 @@ func (svc *chart) Update(mod *types.Chart) (c *types.Chart, err error) { }) } -func (svc *chart) DeleteByID(chartID uint64) error { - return svc.chartRepo.DeleteByID(chartID) +func (svc *chart) DeleteByID(ID uint64) error { + if !svc.prmSvc.CanDeleteChartByID(ID) { + return errors.New("not allowed to delete this chart") + } + + return svc.chartRepo.DeleteByID(ID) } diff --git a/crm/service/module.go b/crm/service/module.go index 35a8ecf0c..c649bb5bf 100644 --- a/crm/service/module.go +++ b/crm/service/module.go @@ -15,6 +15,8 @@ type ( db *factory.DB ctx context.Context + prmSvc PermissionsService + moduleRepo repository.ModuleRepository pageRepo repository.PageRepository } @@ -23,7 +25,7 @@ type ( With(ctx context.Context) ModuleService FindByID(moduleID uint64) (*types.Module, error) - Find() ([]*types.Module, error) + Find() (types.ModuleSet, error) Create(module *types.Module) (*types.Module, error) Update(module *types.Module) (*types.Module, error) @@ -32,46 +34,67 @@ type ( ) func Module() ModuleService { - return (&module{}).With(context.Background()) + return (&module{ + prmSvc: DefaultPermissions, + }).With(context.Background()) } -func (s *module) With(ctx context.Context) ModuleService { +func (svc *module) With(ctx context.Context) ModuleService { db := repository.DB(ctx) return &module{ - db: db, - ctx: ctx, + db: db, + ctx: ctx, + + prmSvc: svc.prmSvc.With(ctx), + moduleRepo: repository.Module(ctx, db), pageRepo: repository.Page(ctx, db), } } -func (s *module) FindByID(id uint64) (*types.Module, error) { - mod, err := s.moduleRepo.FindByID(id) - if err != nil { - return nil, err +func (svc *module) FindByID(id uint64) (m *types.Module, err error) { + if m, err = svc.moduleRepo.FindByID(id); err != nil { + return + } else if !svc.prmSvc.CanReadModule(m) { + return nil, errors.New("not allowed to access this module") } - return mod, err + return } -func (s *module) Find() ([]*types.Module, error) { - return s.moduleRepo.Find() +func (svc *module) Find() (mm types.ModuleSet, err error) { + if mm, err = svc.moduleRepo.Find(); err != nil { + return nil, err + } else { + return mm.Filter(func(m *types.Module) (bool, error) { + return svc.prmSvc.CanReadModule(m), nil + }) + } } -func (s *module) Create(mod *types.Module) (*types.Module, error) { +func (svc *module) Create(mod *types.Module) (*types.Module, error) { + if !svc.prmSvc.CanCreateModule() { + return nil, errors.New("not allowed to create this module") + } + if len(mod.Fields) == 0 { return nil, errors.New("Error creating module: no fields") } - return s.moduleRepo.Create(mod) + + return svc.moduleRepo.Create(mod) } -func (s *module) Update(module *types.Module) (m *types.Module, err error) { +func (svc *module) Update(module *types.Module) (m *types.Module, err error) { validate := func() error { if module.ID == 0 { return errors.New("Error updating module: invalid ID") - } else if m, err = s.moduleRepo.FindByID(module.ID); err != nil { + } else if m, err = svc.moduleRepo.FindByID(module.ID); err != nil { return errors.Wrap(err, "Error while loading module for update") } else { + if !svc.prmSvc.CanUpdateModule(m) { + return errors.New("not allowed to update this module") + } + module.CreatedAt = m.CreatedAt } @@ -86,12 +109,16 @@ func (s *module) Update(module *types.Module) (m *types.Module, err error) { return nil, err } - return m, s.db.Transaction(func() (err error) { - m, err = s.moduleRepo.Update(module) + return m, svc.db.Transaction(func() (err error) { + m, err = svc.moduleRepo.Update(module) return }) } -func (s *module) DeleteByID(id uint64) error { - return s.moduleRepo.DeleteByID(id) +func (svc *module) DeleteByID(ID uint64) error { + if !svc.prmSvc.CanDeleteModuleByID(ID) { + return errors.New("not allowed to delete this module") + } + + return svc.moduleRepo.DeleteByID(ID) } diff --git a/crm/service/page.go b/crm/service/page.go index fa19742a2..3fad01a76 100644 --- a/crm/service/page.go +++ b/crm/service/page.go @@ -2,8 +2,8 @@ package service import ( "context" - "errors" + "github.com/pkg/errors" "github.com/titpetric/factory" "github.com/crusttech/crust/crm/repository" @@ -15,6 +15,8 @@ type ( db *factory.DB ctx context.Context + prmSvc PermissionsService + pageRepo repository.PageRepository moduleRepo repository.ModuleRepository } @@ -38,45 +40,59 @@ type ( ) func Page() PageService { - return (&page{}).With(context.Background()) + return (&page{ + prmSvc: DefaultPermissions, + }).With(context.Background()) } -func (s *page) With(ctx context.Context) PageService { +func (svc *page) With(ctx context.Context) PageService { db := repository.DB(ctx) return &page{ - db: db, - ctx: ctx, + db: db, + ctx: ctx, + + prmSvc: svc.prmSvc.With(ctx), + pageRepo: repository.Page(ctx, db), moduleRepo: repository.Module(ctx, db), } } -func (s *page) FindByID(id uint64) (*types.Page, error) { - return s.pageRepo.FindByID(id) +func (svc *page) FindByID(id uint64) (p *types.Page, err error) { + return svc.checkPermissions(svc.pageRepo.FindByID(id)) } -func (s *page) FindByModuleID(moduleID uint64) (*types.Page, error) { - return s.pageRepo.FindByModuleID(moduleID) +func (svc *page) FindByModuleID(moduleID uint64) (p *types.Page, err error) { + return svc.checkPermissions(svc.pageRepo.FindByModuleID(moduleID)) } -func (s *page) FindBySelfID(selfID uint64) (pages types.PageSet, err error) { - return s.pageRepo.FindBySelfID(selfID) +func (svc *page) checkPermissions(p *types.Page, err error) (*types.Page, error) { + if err != nil { + return nil, err + } else if !svc.prmSvc.CanReadPage(p) { + return nil, errors.New("not allowed to access this page") + } + + return p, err } -func (s *page) Find() (pages types.PageSet, err error) { - return s.pageRepo.Find() +func (svc *page) FindBySelfID(selfID uint64) (pp types.PageSet, err error) { + return svc.filterPageSet(svc.pageRepo.FindBySelfID(selfID)) } -func (s *page) Tree() (pages types.PageSet, err error) { +func (svc *page) Find() (pages types.PageSet, err error) { + return svc.filterPageSet(svc.pageRepo.Find()) +} + +func (svc *page) Tree() (pages types.PageSet, err error) { var tree types.PageSet - return tree, s.db.Transaction(func() (err error) { - if pages, err = s.pageRepo.Find(); err != nil { + return tree, svc.db.Transaction(func() (err error) { + if pages, err = svc.filterPageSet(svc.pageRepo.Find()); err != nil { return } // No preloading - we do not need (or should have) any modules - // associated with us _ = pages.Walk(func(p *types.Page) error { if p.SelfID == 0 { @@ -100,19 +116,32 @@ func (s *page) Tree() (pages types.PageSet, err error) { }) } -func (s *page) FindRecordPages() (pages types.PageSet, err error) { - return s.pageRepo.FindRecordPages() +func (svc *page) FindRecordPages() (pages types.PageSet, err error) { + return svc.pageRepo.FindRecordPages() } -func (s *page) Reorder(selfID uint64, pageIDs []uint64) error { - return s.pageRepo.Reorder(selfID, pageIDs) +func (svc *page) filterPageSet(pp types.PageSet, err error) (types.PageSet, error) { + if err != nil { + return nil, err + } + + return pp.Filter(func(m *types.Page) (bool, error) { + return svc.prmSvc.CanReadPage(m), nil + }) } -func (s *page) Create(page *types.Page) (p *types.Page, err error) { +func (svc *page) Reorder(selfID uint64, pageIDs []uint64) error { + return svc.pageRepo.Reorder(selfID, pageIDs) +} + +func (svc *page) Create(page *types.Page) (p *types.Page, err error) { validate := func() error { + if !svc.prmSvc.CanCreatePage() { + return errors.New("not allowed to create this module") + } + if page.ModuleID > 0 { - // @todo check if module exists! - if p, err = s.pageRepo.FindByModuleID(page.ModuleID); err != nil { + if p, err = svc.pageRepo.FindByModuleID(page.ModuleID); err != nil { return err } else if p.ID > 0 { return errors.New("Page for module already exists") @@ -123,20 +152,26 @@ func (s *page) Create(page *types.Page) (p *types.Page, err error) { if err := validate(); err != nil { return nil, err } - return p, s.db.Transaction(func() (err error) { - p, err = s.pageRepo.Create(page) + return p, svc.db.Transaction(func() (err error) { + p, err = svc.pageRepo.Create(page) return }) } -func (s *page) Update(page *types.Page) (p *types.Page, err error) { +func (svc *page) Update(page *types.Page) (p *types.Page, err error) { validate := func() error { if page.ID == 0 { return errors.New("Error when savig page, invalid ID") + } else if p, err = svc.pageRepo.FindByID(page.ID); err != nil { + return errors.Wrap(err, "Error while loading module for update") + } else { + if !svc.prmSvc.CanUpdatePage(p) { + return errors.New("not allowed to update this pahe") + } } + if page.ModuleID > 0 { - // @todo check if module exists! - if p, err = s.pageRepo.FindByModuleID(page.ModuleID); err != nil { + if p, err = svc.pageRepo.FindByModuleID(page.ModuleID); err != nil { return err } else if p.ID > 0 && page.ID != p.ID { return errors.New("Page for module already exists") @@ -147,12 +182,16 @@ func (s *page) Update(page *types.Page) (p *types.Page, err error) { if err := validate(); err != nil { return nil, err } - return p, s.db.Transaction(func() (err error) { - p, err = s.pageRepo.Update(page) + return p, svc.db.Transaction(func() (err error) { + p, err = svc.pageRepo.Update(page) return }) } -func (s *page) DeleteByID(id uint64) error { - return s.pageRepo.DeleteByID(id) +func (svc *page) DeleteByID(ID uint64) error { + if !svc.prmSvc.CanDeletePageByID(ID) { + return errors.New("not allowed to delete this page") + } + + return svc.pageRepo.DeleteByID(ID) } diff --git a/crm/service/record.go b/crm/service/record.go index eb6dc16db..299abdb51 100644 --- a/crm/service/record.go +++ b/crm/service/record.go @@ -20,10 +20,11 @@ type ( db *factory.DB ctx context.Context + prmSvc PermissionsService + userSvc systemService.UserService + repository repository.RecordRepository moduleRepo repository.ModuleRepository - - userSvc systemService.UserService } RecordService interface { @@ -45,6 +46,7 @@ type ( func Record() RecordService { return (&record{ + prmSvc: DefaultPermissions, userSvc: systemService.DefaultUser, }).With(context.Background()) } @@ -52,11 +54,14 @@ func Record() RecordService { func (svc *record) With(ctx context.Context) RecordService { db := repository.DB(ctx) return &record{ - db: db, - ctx: ctx, + db: db, + ctx: ctx, + + prmSvc: svc.prmSvc.With(ctx), + userSvc: svc.userSvc.With(ctx), + repository: repository.Record(ctx, db), moduleRepo: repository.Module(ctx, db), - userSvc: svc.userSvc.With(ctx), } } @@ -64,6 +69,8 @@ func (svc *record) FindByID(recordID uint64) (r *types.Record, err error) { err = svc.db.Transaction(func() (err error) { if r, err = svc.repository.FindByID(recordID); err != nil { return + } else if !svc.prmSvc.CanReadRecord(r) { + return errors.New("not allowed to access this record") } if err = svc.preloadValues(r); err != nil { @@ -82,6 +89,8 @@ func (svc *record) Report(moduleID uint64, metrics, dimensions, filter string) ( err = svc.db.Transaction(func() (err error) { if module, err = svc.moduleRepo.FindByID(moduleID); err != nil { return + } else if !svc.prmSvc.CanReadRecord(module) { + return errors.New("not allowed to access this record") } out, err = svc.repository.Report(module, metrics, dimensions, filter) @@ -97,6 +106,8 @@ func (svc *record) Find(moduleID uint64, filter, sort string, page, perPage int) err = svc.db.Transaction(func() (err error) { if module, err = svc.moduleRepo.FindByID(moduleID); err != nil { return + } else if !svc.prmSvc.CanReadRecord(module) { + return errors.New("not allowed to access this record") } if rsp, err = svc.repository.Find(module, filter, sort, page, perPage); err != nil { @@ -120,6 +131,8 @@ func (svc *record) Create(in *types.Record) (record *types.Record, err error) { err = svc.db.Transaction(func() (err error) { if module, err = svc.moduleRepo.FindByID(in.ModuleID); err != nil { return + } else if !svc.prmSvc.CanCreateRecord(module) { + return errors.New("not allowed to create records for this module") } if err = svc.sanitizeValues(module, in.Values); err != nil { @@ -157,6 +170,8 @@ func (svc *record) Update(updated *types.Record) (record *types.Record, err erro if record, err = svc.repository.FindByID(updated.ID); err != nil { return errors.Wrap(err, "nonexistent record") + } else if !svc.prmSvc.CanUpdateRecord(record) { + return errors.New("not allowed to update this record") } if module, err = svc.moduleRepo.FindByID(updated.ModuleID); err != nil { diff --git a/crm/service/service.go b/crm/service/service.go index 4d4f9afbd..391e74996 100644 --- a/crm/service/service.go +++ b/crm/service/service.go @@ -32,13 +32,13 @@ func Init() { log.Fatalf("Failed to initialize store: %v", err) } + DefaultPermissions = Permissions() DefaultRecord = Record() DefaultModule = Module() DefaultTrigger = Trigger() DefaultPage = Page() DefaultChart = Chart() DefaultNotification = Notification() - DefaultPermissions = Permissions() DefaultAttachment = Attachment(fs) }) } diff --git a/crm/service/trigger.go b/crm/service/trigger.go index 4ec455ed6..41d4d082b 100644 --- a/crm/service/trigger.go +++ b/crm/service/trigger.go @@ -2,8 +2,8 @@ package service import ( "context" - "errors" + "github.com/pkg/errors" "github.com/titpetric/factory" "github.com/crusttech/crust/crm/repository" @@ -15,6 +15,8 @@ type ( db *factory.DB ctx context.Context + prmSvc PermissionsService + triggerRepo repository.TriggerRepository moduleRepo repository.ModuleRepository } @@ -32,56 +34,85 @@ type ( ) func Trigger() TriggerService { - return (&trigger{}).With(context.Background()) + return (&trigger{ + prmSvc: DefaultPermissions}).With(context.Background()) } -func (s *trigger) With(ctx context.Context) TriggerService { +func (svc *trigger) With(ctx context.Context) TriggerService { db := repository.DB(ctx) return &trigger{ - db: db, - ctx: ctx, + db: db, + ctx: ctx, + + prmSvc: svc.prmSvc.With(ctx), + triggerRepo: repository.Trigger(ctx, db), moduleRepo: repository.Module(ctx, db), } } -func (s *trigger) FindByID(id uint64) (*types.Trigger, error) { - return s.triggerRepo.FindByID(id) -} - -func (s *trigger) Find(filter types.TriggerFilter) (types.TriggerSet, error) { - return s.triggerRepo.Find(filter) -} - -func (s *trigger) Create(trigger *types.Trigger) (p *types.Trigger, err error) { - validate := func() error { - return nil +func (svc *trigger) FindByID(id uint64) (t *types.Trigger, err error) { + if t, err = svc.triggerRepo.FindByID(id); err != nil { + return + } else if !svc.prmSvc.CanReadTrigger(t) { + return nil, errors.New("not allowed to access this trigger") } - if err := validate(); err != nil { + + return +} + +func (svc *trigger) Find(filter types.TriggerFilter) (tt types.TriggerSet, err error) { + if tt, err = svc.triggerRepo.Find(filter); err != nil { return nil, err + } else { + return tt.Filter(func(m *types.Trigger) (bool, error) { + return svc.prmSvc.CanReadTrigger(m), nil + }) } - return p, s.db.Transaction(func() (err error) { - p, err = s.triggerRepo.Create(trigger) +} + +func (svc *trigger) Create(trigger *types.Trigger) (p *types.Trigger, err error) { + if !svc.prmSvc.CanCreateTrigger() { + return nil, errors.New("not allowed to create this trigger") + } + + return p, svc.db.Transaction(func() (err error) { + p, err = svc.triggerRepo.Create(trigger) return }) } -func (s *trigger) Update(trigger *types.Trigger) (p *types.Trigger, err error) { +func (svc *trigger) Update(trigger *types.Trigger) (t *types.Trigger, err error) { validate := func() error { if trigger.ID == 0 { - return errors.New("could not update trigger, invalid ID") + return errors.New("Error updating trigger: invalid ID") + } else if t, err = svc.triggerRepo.FindByID(trigger.ID); err != nil { + return errors.Wrap(err, "Error while loading trigger for update") + } else { + if !svc.prmSvc.CanUpdateModule(t) { + return errors.New("not allowed to update this trigger") + } + + trigger.CreatedAt = t.CreatedAt } + return nil } + if err := validate(); err != nil { return nil, err } - return p, s.db.Transaction(func() (err error) { - p, err = s.triggerRepo.Update(trigger) + + return t, svc.db.Transaction(func() (err error) { + t, err = svc.triggerRepo.Update(trigger) return }) } -func (s *trigger) DeleteByID(id uint64) error { - return s.triggerRepo.DeleteByID(id) +func (svc *trigger) DeleteByID(ID uint64) error { + if !svc.prmSvc.CanDeleteTriggerByID(ID) { + return errors.New("not allowed to delete this trigger") + } + + return svc.triggerRepo.DeleteByID(ID) }