diff --git a/server/app/resources.cue b/server/app/resources.cue index 3a9be8a78..001d945d2 100644 --- a/server/app/resources.cue +++ b/server/app/resources.cue @@ -51,6 +51,16 @@ resources: { [key=_]: {"handle": key, "component": "system", "platform": "cortez } } + filter: { + struct: { + resource: {goType: "[]string", ident: "resource", storeIdent: "resource"} + operation: {goType: "string", ident: "operation", storeIdent: "operation"} + role_id: {goType: "uint64", ident: "roleID", storeIdent: "rel_role"} + } + + byValue: ["resource", "operation", "role_id"] + } + store: { ident: "rbacRule" diff --git a/server/pkg/rbac/roles.go b/server/pkg/rbac/roles.go index c4ab27e97..9267b3774 100644 --- a/server/pkg/rbac/roles.go +++ b/server/pkg/rbac/roles.go @@ -116,7 +116,7 @@ func statRoles(rr ...*Role) (stats map[roleKind]int) { } // compare list of session roles (ids) with preloaded roles and calculate the final list -func getContextRoles(s Session, res Resource, preloadedRoles []*Role) (out partRoles) { +func getContextRoles(s Session, res Resource, preloadedRoles ...*Role) (out partRoles) { var ( mm = slice.ToUint64BoolMap(s.Roles()) scope = make(map[string]interface{}) diff --git a/server/pkg/rbac/roles_test.go b/server/pkg/rbac/roles_test.go index f6172cafd..e18b2edff 100644 --- a/server/pkg/rbac/roles_test.go +++ b/server/pkg/rbac/roles_test.go @@ -97,7 +97,7 @@ func Test_getContextRoles(t *testing.T) { req = require.New(t) ) - req.Equal(partitionRoles(tc.output...), getContextRoles(&session{rr: tc.sessionRoles}, tc.res, tc.preloadRoles)) + req.Equal(partitionRoles(tc.output...), getContextRoles(&session{rr: tc.sessionRoles}, tc.res, tc.preloadRoles...)) }) } } diff --git a/server/pkg/rbac/service.go b/server/pkg/rbac/service.go index 8d58c83e9..d1bf92cb2 100644 --- a/server/pkg/rbac/service.go +++ b/server/pkg/rbac/service.go @@ -28,6 +28,10 @@ type ( // RuleFilter is a dummy struct to satisfy store codegen RuleFilter struct { + Resource []string + Operation string + RoleID uint64 + Limit uint } @@ -94,7 +98,7 @@ func (svc *service) Can(ses Session, op string, res Resource) bool { // See RuleSet's Check() func for details func (svc *service) Check(ses Session, op string, res Resource) (a Access) { var ( - fRoles = getContextRoles(ses, res, svc.roles) + fRoles = getContextRoles(ses, res, svc.roles...) ) if hasWildcards(res.RbacResource()) { @@ -150,7 +154,7 @@ func (svc *service) Trace(ses Session, op string, res Resource) *Trace { } var ( - fRoles = getContextRoles(ses, res, svc.roles) + fRoles = getContextRoles(ses, res, svc.roles...) ) _ = check(svc.indexed, fRoles, op, res.RbacResource(), t) diff --git a/server/pkg/rbac/store_interface.go b/server/pkg/rbac/store_interface.go index e98ee22ba..75a417be3 100644 --- a/server/pkg/rbac/store_interface.go +++ b/server/pkg/rbac/store_interface.go @@ -2,6 +2,8 @@ package rbac import ( "context" + + "github.com/cortezaproject/corteza/server/system/types" ) type ( @@ -11,5 +13,8 @@ type ( UpsertRbacRule(ctx context.Context, rr ...*Rule) error DeleteRbacRule(ctx context.Context, rr ...*Rule) error TruncateRbacRules(ctx context.Context) error + + // @todo this isn't ok since we're referencing sys types + SearchRoles(ctx context.Context, f types.RoleFilter) (types.RoleSet, types.RoleFilter, error) } ) diff --git a/server/pkg/rbac/wrapper.go b/server/pkg/rbac/wrapper.go new file mode 100644 index 000000000..48ab7a7fa --- /dev/null +++ b/server/pkg/rbac/wrapper.go @@ -0,0 +1,430 @@ +package rbac + +import ( + "context" + "fmt" + "math" + "sort" + "strings" + "time" + + "github.com/cortezaproject/corteza/server/pkg/filter" + "github.com/cortezaproject/corteza/server/system/types" + "github.com/davecgh/go-spew/spew" +) + +type ( + WrapperConfig struct { + InitialIndexedRoles []uint64 + MaxIndexSize int + } + + wrapperService struct { + cfg WrapperConfig + + store rbacRulesStore + counter *usageCounter + index *wrapperIndex + roles []*Role + } +) + +func dftWrapperCfg(base WrapperConfig) (out WrapperConfig) { + out = base + + if base.MaxIndexSize == 0 { + out.MaxIndexSize = -1 + } + + return out +} + +func Wrapper(ctx context.Context, store rbacRulesStore, cc WrapperConfig) (x *wrapperService, err error) { + cc = dftWrapperCfg(cc) + + uc := &usageCounter{ + incChan: make(chan uint64, 256), + sigChan: make(chan counterEntry, 8), + } + + x = &wrapperService{ + cfg: cc, + + store: store, + counter: uc, + } + + x.roles, err = x.loadRoles(ctx, store) + if err != nil { + return + } + + x.index, err = x.loadIndex(ctx, store, x.roles) + if err != nil { + return + } + + uc.watch(ctx) + x.watch(ctx) + + return +} + +func (svc *wrapperService) Clear() { + svc.store = nil + svc.counter = nil + svc.index = nil + svc.roles = nil +} + +func (svc *wrapperService) Can(ses Session, op string, res Resource) (ok bool, err error) { + ac, err := svc.Check(ses, op, res) + if err != nil { + return + } + + return ac == Allow, nil +} + +func (svc *wrapperService) Check(ses Session, op string, res Resource) (a Access, err error) { + if hasWildcards(res.RbacResource()) { + // prevent use of wildcard resources for checking permissions + return Inherit, nil + } + + fRoles := getContextRoles(ses, res, svc.roles...) + + return svc.check(ses.Context(), fRoles, op, res.RbacResource()) +} + +func (svc *wrapperService) check(ctx context.Context, rolesByKind partRoles, op, res string) (a Access, err error) { + if member(rolesByKind, AnonymousRole) && len(rolesByKind) > 1 { + // Integrity check; when user is member of anonymous role + // should not be member of any other type of role + return resolve(nil, Deny, failedIntegrityCheck), nil + } + + if member(rolesByKind, BypassRole) { + // if user has at least one bypass role, we allow access + return resolve(nil, Allow, bypassRoleMembership), nil + } + + // if indexedRules.empty() { + // // no rules to check + // return resolve(nil, Inherit, noRules) + // } + + var ( + match *Rule + allowed bool + ) + + indexed, unindexed, err := svc.segmentRoles(ctx, rolesByKind) + if err != nil { + return Inherit, err + } + + // + // if trace != nil { + // // from this point on, there is a chance trace (if set) + // // will contain some rules. + // // + // // Stable order needs to be ensured: there is no production + // // code that relies on that but tests might fail and API + // // response would be flaky. + // defer sortTraceRules(trace) + // } + + st := evlState{ + op: op, + res: res, + + unindexedRoles: unindexed, + indexedRoles: indexed, + } + + st.unindexedRules, err = svc.pullUnindexed(ctx, unindexed, op, res) + if err != nil { + return Inherit, err + } + + // Priority is important here. We want to have + // stable RBAC check behaviour and ability + // to override allow/deny depending on how niche the role (type) is: + // - context (eg owners) are more niche than common + // - rules for common roles are more important than authenticated and anonymous role types + // + // Note that bypass roles are intentionally ignored here; if user is member of + // bypass role there is no need to check any other rule + for _, kind := range []roleKind{ContextRole, CommonRole, AuthenticatedRole, AnonymousRole} { + // not a member of any role of this kind + if len(rolesByKind[kind]) == 0 { + continue + } + + // reset allowed to false + // for each role kind + allowed = false + + for r := range rolesByKind[kind] { + match = svc.getMatching(st, kind, r) + + // check all rules for each role the security-context + if match == nil { + // no rules match + continue + } + + // if trace != nil { + // // if trace is enabled, append + // // each matching rule + // trace.Rules = append(trace.Rules, match) + // } + + if match.Access == Deny { + // if we stumble upon Deny we short-circuit the check + return resolve(nil, Deny, ""), nil + } + + if match.Access == Allow { + // allow rule found, we need to check rules on other roles + // before we allow it + allowed = true + } + } + + if allowed { + // at least one of the roles (per role type) in the security context + // allows operation on a resource + return resolve(nil, Allow, ""), nil + } + } + + // No rule matched + return resolve(nil, Inherit, noMatch), nil +} + +func (svc *wrapperService) segmentRoles(ctx context.Context, roles partRoles) (indexed, unindexed partRoles, err error) { + unindexed = partRoles{} + indexed = partRoles{} + + unindexed[CommonRole] = make(map[uint64]bool) + indexed[CommonRole] = make(map[uint64]bool) + + for k, rg := range roles { + for r := range rg { + if svc.index.hasRole(r) { + indexed[k][r] = true + continue + } + + unindexed[k][r] = true + } + } + + return +} + +type ( + evlState struct { + unindexedRoles partRoles + indexedRoles partRoles + + unindexedRules [5]map[uint64][]*Rule + + res string + op string + } +) + +func (svc *wrapperService) getMatching(st evlState, kind roleKind, role uint64) (rule *Rule) { + var ( + aux []*Rule + rules RuleSet + ) + + // Indexed + aux = svc.index.get(role, st.op, st.res) + rules = append(rules, aux...) + + // Unindexed + aux = st.unindexedRules[kind][role] + rules = append(rules, aux...) + + set := RuleSet(rules) + sort.Sort(set) + + for _, s := range set { + if s.Access == Inherit { + continue + } + + return s + } + + return nil +} + +func (svc *wrapperService) pullUnindexed(ctx context.Context, unindexed partRoles, op, res string) (out [5]map[uint64][]*Rule, err error) { + resPerm := make([]string, 0, 8) + resPerm = append(resPerm, res) + + // Get all the resource permissions + // @todo get permissions for parent resources; this will probs be some lookup table + rr := strings.Split(res, "/") + for i := len(rr) - 1; i >= 0; i-- { + rr[i] = "*" + resPerm = append(resPerm, strings.Join(rr, "/")) + } + + for rk, rr := range unindexed { + for r := range rr { + auxRr := make([]*Rule, 0, 4) + auxRr, _, err = svc.store.SearchRbacRules(ctx, RuleFilter{ + RoleID: r, + Resource: resPerm, + Operation: op, + }) + if err != nil { + return + } + + if out[rk] == nil { + out[rk] = map[uint64][]*Rule{ + r: auxRr, + } + } else { + out[rk][r] = auxRr + } + } + } + + return +} + +func (svc *wrapperService) IndexRoleChange(ctx context.Context, roleID uint64) (err error) { + aux, _, err := svc.store.SearchRbacRules(ctx, RuleFilter{ + RoleID: roleID, + }) + if err != nil { + return + } + + // @todo cap this + if len(svc.index.rules.children) > svc.cfg.MaxIndexSize { + // @note probably remove a few extra just to avoid constantly doing this + // @todo is this a good idea? Not sure if worth it since all of this is behind the scene anyways + wp := svc.counter.worstPerformers(4) + svc.index.remove(wp...) + } + + svc.index.add(aux...) + return +} + +func (svc *wrapperService) watch(ctx context.Context) { + t := time.NewTicker(time.Minute * 5) + + go func() { + for { + select { + case <-t.C: + spew.Dump("ticking") + + case change := <-svc.counter.sigChan: + err := svc.IndexRoleChange(ctx, change.key) + if err != nil { + spew.Dump("wrapper watch change err", err) + } + + case <-ctx.Done(): + return + } + } + }() +} + +// // // // // // // // // // // // // // // // // // // // // // // // // // + +func makeKey(op, res string, role uint64) string { + return fmt.Sprintf("%d:%s:%s", role, op, res) +} + +// + +// // // // // // // // // // // // // // // // // // // // // // // // // // +// Boilerplate & state management stuff + +func (svc *wrapperService) loadRoles(ctx context.Context, s rbacRulesStore) (out []*Role, err error) { + auxRoles, _, err := s.SearchRoles(ctx, types.RoleFilter{ + Paging: filter.Paging{ + Limit: 0, + }, + }) + if err != nil { + return + } + + for _, ar := range auxRoles { + out = append(out, &Role{ + id: ar.ID, + handle: ar.Handle, + kind: CommonRole, + }) + } + + return +} + +func (svc *wrapperService) loadIndex(ctx context.Context, s rbacRulesStore, allRoles []*Role) (out *wrapperIndex, err error) { + // @todo smarter way to figure out what/how many roles we want to load up + roles := svc.getIndexRoles(allRoles) + + rules := make(RuleSet, 0, 1024) + var aux RuleSet + for _, role := range roles { + aux, _, err = s.SearchRbacRules(ctx, RuleFilter{ + RoleID: role.id, + Limit: 0, + }) + if err != nil { + return + } + + rules = append(rules, aux...) + } + + out = &wrapperIndex{ + rules: buildRuleIndex(rules), + } + + return +} + +func (svc *wrapperService) getIndexRoles(allRoles []*Role) (out []*Role) { + // User-specified what we want to index; respect that to the t + if len(svc.cfg.InitialIndexedRoles) > 0 { + for _, r := range allRoles { + for _, ir := range svc.cfg.InitialIndexedRoles { + if r.id == ir { + out = append(out, r) + } + } + } + + return + } + + // Straight up limit + // @todo add some counters to figure out which roles are most used from the start + if svc.cfg.MaxIndexSize == -1 { + return allRoles + } + + if svc.cfg.MaxIndexSize == 0 { + return nil + } + + // @todo smarter way to figure out what/how many roles we want to load up + return allRoles[:int(math.Min(float64(len(allRoles)), float64(svc.cfg.MaxIndexSize)))] +} diff --git a/server/pkg/rbac/wrapper_counter.go b/server/pkg/rbac/wrapper_counter.go new file mode 100644 index 000000000..fe0fb340e --- /dev/null +++ b/server/pkg/rbac/wrapper_counter.go @@ -0,0 +1,99 @@ +package rbac + +import ( + "context" + "sort" + "sync" + "time" +) + +type ( + usageCounter struct { + index map[uint64]uint + + lock sync.RWMutex + + sigThreshold uint + + incChan chan uint64 + sigChan chan counterEntry + } + + counterEntry struct { + key uint64 + count uint + } + + MinHeap []counterEntry +) + +func (svc *usageCounter) worstPerformers(n int) (out []uint64) { + svc.lock.RLock() + defer svc.lock.RUnlock() + + // Code to get n elements with the smallest count + + hh := make(MinHeap, 0, len(svc.index)) + for k, v := range svc.index { + hh = append(hh, counterEntry{key: k, count: v}) + } + + sort.Sort(hh) + + for _, x := range hh { + out = append(out, x.key) + + if len(out) >= n { + return + } + } + + return +} + +func (svc *usageCounter) inc(key uint64) { + svc.lock.Lock() + defer svc.lock.Unlock() + + count := svc.index[key] + 1 + svc.index[key] = count + + if count >= svc.sigThreshold { + delete(svc.index, key) + svc.sigChan <- counterEntry{key: key, count: count} + } +} + +func (svc *usageCounter) clean() { + svc.lock.Lock() + defer svc.lock.Unlock() + + for k, v := range svc.index { + if v < uint(float64(svc.sigThreshold)*0.05) { + delete(svc.index, k) + } + } +} + +func (svc *usageCounter) watch(ctx context.Context) { + cleanT := time.NewTicker(time.Minute * 10) + + go func() { + for { + select { + case <-ctx.Done(): + return + + case <-cleanT.C: + svc.clean() + + case key := <-svc.incChan: + svc.inc(key) + } + } + }() +} + +func (h MinHeap) Len() int { return len(h) } +func (h MinHeap) Less(i, j int) bool { return h[i].count < h[j].count } +func (h MinHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } diff --git a/server/pkg/rbac/wrapper_index.go b/server/pkg/rbac/wrapper_index.go new file mode 100644 index 000000000..3cde632c9 --- /dev/null +++ b/server/pkg/rbac/wrapper_index.go @@ -0,0 +1,40 @@ +package rbac + +import "sync" + +type ( + wrapperIndex struct { + mux sync.RWMutex + rules *ruleIndex + } +) + +func (svc *wrapperIndex) get(role uint64, op string, res string) (out []*Rule) { + svc.mux.RLock() + defer svc.mux.RUnlock() + + return svc.rules.get(role, op, res) +} + +func (svc *wrapperIndex) hasRole(role uint64) (ok bool) { + svc.mux.RLock() + defer svc.mux.RUnlock() + + _, ok = svc.rules.children[role] + return +} + +// @todo since it's like so, we might not need the trie to have deletable elements +func (svc *wrapperIndex) remove(roles ...uint64) { + svc.mux.Lock() + defer svc.mux.Unlock() + for _, r := range roles { + delete(svc.rules.children, r) + } +} + +func (svc *wrapperIndex) add(rules ...*Rule) { + svc.mux.Lock() + defer svc.mux.Unlock() + svc.rules.add(rules...) +} diff --git a/server/store/adapters/rdbms/filters.gen.go b/server/store/adapters/rdbms/filters.gen.go index f35415d2c..bb28b6025 100644 --- a/server/store/adapters/rdbms/filters.gen.go +++ b/server/store/adapters/rdbms/filters.gen.go @@ -1145,6 +1145,18 @@ func QueueMessageFilter(d drivers.Dialect, f systemType.QueueMessageFilter) (ee // This function is auto-generated func RbacRuleFilter(d drivers.Dialect, f rbacType.RuleFilter) (ee []goqu.Expression, _ rbacType.RuleFilter, err error) { + if ss := trimStringSlice(f.Resource); len(ss) > 0 { + ee = append(ee, goqu.C("resource").In(ss)) + } + + if val := strings.TrimSpace(f.Operation); len(val) > 0 { + ee = append(ee, goqu.C("operation").Eq(f.Operation)) + } + + if f.RoleID > 0 { + ee = append(ee, goqu.C("rel_role").Eq(f.RoleID)) + } + return ee, f, err }