add(internal): rules feature
- add make test.rules target - add rules api implementation - add sql migration for rules table - add unit tests to verify it works
This commit is contained in:
parent
5825197a81
commit
d17a6fb8aa
4
Makefile
4
Makefile
@ -114,6 +114,10 @@ test.rbac: $(GOTEST)
|
||||
$(GOTEST) -covermode count -coverprofile .cover.out -v ./internal/rbac/...
|
||||
$(GO) tool cover -func=.cover.out | grep --color "^\|[^0-9]0.0%"
|
||||
|
||||
test.rules: $(GOTEST)
|
||||
$(GOTEST) -covermode count -coverprofile .cover.out -v ./internal/rules/...
|
||||
$(GO) tool cover -func=.cover.out | grep --color "^\|[^0-9]0.0%"
|
||||
|
||||
test.mail: $(GOTEST)
|
||||
$(GOTEST) -covermode count -coverprofile .cover.out -v ./internal/mail/...
|
||||
$(GO) tool cover -func=.cover.out | grep --color "^\|[^0-9]0.0%"
|
||||
|
||||
16
internal/rules/interfaces.go
Normal file
16
internal/rules/interfaces.go
Normal file
@ -0,0 +1,16 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/titpetric/factory"
|
||||
)
|
||||
|
||||
type ResourcesInterface interface {
|
||||
With(ctx context.Context, db *factory.DB) ResourcesInterface
|
||||
|
||||
CheckAccessMulti(resource string, operation string) error
|
||||
CheckAccess(resource string, operation string) error
|
||||
|
||||
Grant(resource string, teamID uint64, operations []string, value Access) error
|
||||
}
|
||||
64
internal/rules/main_test.go
Normal file
64
internal/rules/main_test.go
Normal file
@ -0,0 +1,64 @@
|
||||
package rules_test
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
"github.com/namsral/flag"
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
systemMigrate "github.com/crusttech/crust/system/db"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// @todo this is a very optimistic initialization, make it more robust
|
||||
godotenv.Load("../../.env")
|
||||
|
||||
prefix := "system"
|
||||
dsn := ""
|
||||
|
||||
p := func(s string) string {
|
||||
return prefix + "-" + s
|
||||
}
|
||||
|
||||
flag.StringVar(&dsn, p("db-dsn"), "crust:crust@tcp(db1:3306)/crust?collation=utf8mb4_general_ci", "DSN for database connection")
|
||||
flag.Parse()
|
||||
|
||||
if testing.Short() {
|
||||
return
|
||||
}
|
||||
|
||||
factory.Database.Add("default", dsn)
|
||||
|
||||
db := factory.Database.MustGet()
|
||||
db.Profiler = &factory.Database.ProfilerStdout
|
||||
|
||||
// migrate database schema
|
||||
if err := systemMigrate.Migrate(db); err != nil {
|
||||
log.Printf("Error running migrations: %+v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func assert(t *testing.T, ok bool, format string, args ...interface{}) bool {
|
||||
if !ok {
|
||||
t.Fatalf(format, args...)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
func must(t *testing.T, err error, message ...string) {
|
||||
if len(message) > 0 {
|
||||
assert(t, err == nil, message[0]+": %+v", err)
|
||||
return
|
||||
}
|
||||
assert(t, err == nil, "Error: %+v", err)
|
||||
}
|
||||
|
||||
func mustFail(t *testing.T, err error) {
|
||||
assert(t, err != nil, "Expected error, got nil")
|
||||
}
|
||||
123
internal/rules/resources.go
Normal file
123
internal/rules/resources.go
Normal file
@ -0,0 +1,123 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/internal/auth"
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
type Access string
|
||||
|
||||
const (
|
||||
Allow Access = "yes"
|
||||
Deny = "no"
|
||||
Inherit = ""
|
||||
)
|
||||
|
||||
type resources struct {
|
||||
ctx context.Context
|
||||
db *factory.DB
|
||||
}
|
||||
|
||||
func NewResources(ctx context.Context, db *factory.DB) ResourcesInterface {
|
||||
return (&resources{}).With(ctx, db)
|
||||
}
|
||||
|
||||
func (r *resources) With(ctx context.Context, db *factory.DB) ResourcesInterface {
|
||||
return &resources{
|
||||
ctx: ctx,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resources) identity() uint64 {
|
||||
return auth.GetIdentityFromContext(r.ctx).Identity()
|
||||
}
|
||||
|
||||
func (r *resources) CheckAccessMulti(resource string, operation string) error {
|
||||
user := r.identity()
|
||||
result := []Access{}
|
||||
query := []string{
|
||||
// select rules
|
||||
"select r.value from sys_rules r",
|
||||
// join members
|
||||
"inner join sys_team_member m on (m.rel_team = r.rel_team and m.rel_user=?)",
|
||||
// add conditions
|
||||
"where r.resource LIKE ? and r.operation=?",
|
||||
}
|
||||
resource = strings.Replace(resource, "*", "%", -1)
|
||||
queryString := strings.Join(query, " ")
|
||||
if err := r.db.Select(&result, queryString, user, resource, operation); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// order by deny, allow
|
||||
for _, val := range result {
|
||||
if val == Deny {
|
||||
return errors.New("Access not allowed")
|
||||
}
|
||||
}
|
||||
for _, val := range result {
|
||||
if val == Allow {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errors.New("Access not allowed")
|
||||
}
|
||||
|
||||
func (r *resources) CheckAccess(resource string, operation string) error {
|
||||
user := r.identity()
|
||||
result := []Access{}
|
||||
query := []string{
|
||||
// select rules
|
||||
"select r.value from sys_rules r",
|
||||
// join members
|
||||
"inner join sys_team_member m on (m.rel_team = r.rel_team and m.rel_user=?)",
|
||||
// add conditions
|
||||
"where r.resource=? and r.operation=?",
|
||||
}
|
||||
queryString := strings.Join(query, " ")
|
||||
if err := r.db.Select(&result, queryString, user, resource, operation); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// order by deny, allow
|
||||
for _, val := range result {
|
||||
if val == Deny {
|
||||
return errors.New("Access not allowed")
|
||||
}
|
||||
}
|
||||
for _, val := range result {
|
||||
if val == Allow {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errors.New("Access not allowed")
|
||||
}
|
||||
|
||||
func (r *resources) Grant(resource string, teamID uint64, operations []string, value Access) error {
|
||||
row := types.Rules{
|
||||
TeamID: teamID,
|
||||
Resource: resource,
|
||||
Value: string(value),
|
||||
}
|
||||
var err error
|
||||
for _, operation := range operations {
|
||||
row.Operation = operation
|
||||
switch value {
|
||||
case Inherit:
|
||||
_, err = r.db.NamedExec("delete from sys_rules where rel_team=:rel_team and resource=:resource and operation=:operation", row)
|
||||
default:
|
||||
err = r.db.Replace("sys_rules", row)
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
64
internal/rules/resources_test.go
Normal file
64
internal/rules/resources_test.go
Normal file
@ -0,0 +1,64 @@
|
||||
package rules_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/internal/auth"
|
||||
"github.com/crusttech/crust/internal/rules"
|
||||
"github.com/crusttech/crust/system/types"
|
||||
)
|
||||
|
||||
func TestRules(t *testing.T) {
|
||||
user := &types.User{ID: 1337}
|
||||
ctx := auth.SetIdentityToContext(context.Background(), user)
|
||||
|
||||
db := factory.Database.MustGet()
|
||||
mustFail(t, db.Transaction(func() error {
|
||||
db.Insert("sys_user", user)
|
||||
var i uint64 = 0
|
||||
for i < 5 {
|
||||
db.Insert("sys_team", types.Team{ID: i, Name: fmt.Sprintf("Team %d", i)})
|
||||
i++
|
||||
}
|
||||
db.Insert("sys_team_member", types.TeamMember{TeamID: 1, UserID: user.ID})
|
||||
db.Insert("sys_team_member", types.TeamMember{TeamID: 2, UserID: user.ID})
|
||||
|
||||
resources := rules.NewResources(ctx, db)
|
||||
|
||||
// default (unset=deny)
|
||||
{
|
||||
mustFail(t, resources.CheckAccess("channel:1", "edit"))
|
||||
mustFail(t, resources.CheckAccessMulti("channel:*", "edit"))
|
||||
}
|
||||
|
||||
// allow channel:2 group:2 (default deny, multi=allow)
|
||||
{
|
||||
resources.Grant("channel:2", 2, []string{"edit", "delete"}, rules.Allow)
|
||||
mustFail(t, resources.CheckAccess("channel:1", "edit"))
|
||||
must(t, resources.CheckAccess("channel:2", "edit"))
|
||||
must(t, resources.CheckAccessMulti("channel:*", "edit"))
|
||||
}
|
||||
|
||||
// deny channel:1 group:1 (explicit deny, multi=deny)
|
||||
{
|
||||
resources.Grant("channel:1", 1, []string{"edit"}, rules.Deny)
|
||||
mustFail(t, resources.CheckAccess("channel:1", "edit"))
|
||||
must(t, resources.CheckAccess("channel:2", "edit"))
|
||||
mustFail(t, resources.CheckAccessMulti("channel:*", "edit"))
|
||||
}
|
||||
|
||||
// reset (unset=deny)
|
||||
{
|
||||
resources.Grant("channel:2", 2, []string{"edit", "delete"}, rules.Inherit)
|
||||
resources.Grant("channel:1", 1, []string{"edit", "delete"}, rules.Inherit)
|
||||
mustFail(t, resources.CheckAccess("channel:1", "edit"))
|
||||
mustFail(t, resources.CheckAccessMulti("channel:*", "edit"))
|
||||
}
|
||||
return errors.New("Rollback")
|
||||
}))
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
8
system/db/schema/mysql/20190116102104.rules.up.sql
Normal file
8
system/db/schema/mysql/20190116102104.rules.up.sql
Normal file
@ -0,0 +1,8 @@
|
||||
CREATE TABLE `sys_rules` (
|
||||
`rel_team` BIGINT UNSIGNED NOT NULL,
|
||||
`resource` VARCHAR(128) NOT NULL,
|
||||
`operation` VARCHAR(128) NOT NULL,
|
||||
`value` ENUM('no', 'yes') NOT NULL,
|
||||
|
||||
PRIMARY KEY (`rel_team`, `resource`, `operation`)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
|
||||
8
system/types/rules.go
Normal file
8
system/types/rules.go
Normal file
@ -0,0 +1,8 @@
|
||||
package types
|
||||
|
||||
type Rules struct {
|
||||
TeamID uint64 `db:"rel_team"`
|
||||
Resource string `db:"resource"`
|
||||
Operation string `db:"operation"`
|
||||
Value string `db:"value"`
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user