From 9f494a37f94440b7abfe1d3a213216f03901b301 Mon Sep 17 00:00:00 2001 From: Mitja Zivkovic Date: Mon, 11 Mar 2019 21:13:07 +0100 Subject: [PATCH 1/2] upd(messaging) channel permission checks --- internal/rules/main_test.go | 10 ++ messaging/service/channel.go | 135 ++++++++--------- messaging/service/main_test.go | 3 + messaging/service/permissions.go | 139 ++++++++++++++---- messaging/service/permissions_test.go | 200 +++++++++++++------------- messaging/types/channel.go | 23 +-- system/service/validation.go | 8 +- 7 files changed, 313 insertions(+), 205 deletions(-) diff --git a/internal/rules/main_test.go b/internal/rules/main_test.go index 371b1e704..99c107320 100644 --- a/internal/rules/main_test.go +++ b/internal/rules/main_test.go @@ -37,5 +37,15 @@ func TestMain(m *testing.M) { return } + // clean up tables + { + for _, name := range []string{"sys_user", "sys_role", "sys_role_member", "sys_rules"} { + _, err := db.Exec("truncate " + name) + if err != nil { + panic("Error when clearing " + name + ": " + err.Error()) + } + } + } + os.Exit(m.Run()) } diff --git a/messaging/service/channel.go b/messaging/service/channel.go index 625972585..3e1b74d65 100644 --- a/messaging/service/channel.go +++ b/messaging/service/channel.go @@ -58,8 +58,8 @@ type ( ) var ( - ErrUnknownChannelType = errors.New("Unknown ChannelType") - ErrNoPermission = errors.New("No permissions") + ErrUnknownChannelType = errors.New("unknown ChannelType") + ErrNoPermission = errors.New("no permissions") ) const ( @@ -103,8 +103,8 @@ func (svc *channel) FindByID(ID uint64) (ch *types.Channel, err error) { return } - if !ch.CanObserve { - return nil, errors.New("Not allowed to access channel") + if !svc.prm.CanReadChannel(ch) { + return nil, errors.New("not allowed to access channel") } return @@ -121,7 +121,7 @@ func (svc *channel) Find(filter *types.ChannelFilter) (cc types.ChannelSet, err } cc, err = cc.Filter(func(c *types.Channel) (b bool, e error) { - return c.CanObserve, nil + return svc.prm.CanReadChannel(c), nil }) return @@ -219,32 +219,32 @@ func (svc *channel) Create(in *types.Channel) (out *types.Channel, err error) { // Group already exists so let's just return it return nil } else if out != nil && !out.CanObserve { - return errors.New("Not allowed to create this channel due to permission settings") + return errors.New("not allowed to create this channel due to permission settings") } } if in.Type == types.ChannelTypePublic && !svc.prm.CanCreatePublicChannel() { - return errors.New("Not allowed to create public channels") + return errors.New("not allowed to create public channels") } if in.Type == types.ChannelTypePrivate && !svc.prm.CanCreatePrivateChannel() { - return errors.New("Not allowed to create private channels") + return errors.New("not allowed to create private channels") } - if in.Type == types.ChannelTypeGroup && !svc.prm.CanCreateDirectChannel() { - return errors.New("Not allowed to create group channels") + if in.Type == types.ChannelTypeGroup && !svc.prm.CanCreateGroupChannel() { + return errors.New("not allowed to create group channels") } if len(in.Name) == 0 && in.Type != types.ChannelTypeGroup { - return errors.New("Channel name not provided") + return errors.New("channel name not provided") } if settingsChannelNameLength > 0 && len(in.Name) > settingsChannelNameLength { - return errors.Errorf("Channel name (%d characters) too long (max: %d)", len(in.Name), settingsChannelNameLength) + return errors.Errorf("channel name (%d characters) too long (max: %d)", len(in.Name), settingsChannelNameLength) } if len(in.Topic) > 0 && settingsChannelTopicLength > 0 && len(in.Topic) > settingsChannelTopicLength { - return errors.Errorf("Channel topic (%d characters) too long (max: %d)", len(in.Topic), settingsChannelTopicLength) + return errors.Errorf("channel topic (%d characters) too long (max: %d)", len(in.Topic), settingsChannelTopicLength) } // This is a fresh channel, just copy values @@ -342,24 +342,21 @@ func (svc *channel) Update(in *types.Channel) (ch *types.Channel, err error) { return } - if !ch.CanUpdate { - return errors.New("Not allowed to update this channel") + if !svc.prm.CanUpdateChannel(ch) { + return errors.New("not allowed to update this channel") } if in.Type.IsValid() && ch.Type != in.Type { - // @todo [SECURITY] check if user can update channel type to public - if in.Type == types.ChannelTypePublic && false { - return errors.New("Not allowed to change type of this channel to **public**") + if in.Type == types.ChannelTypePublic && !svc.prm.CanCreatePublicChannel() { + return errors.New("not allowed to change type of this channel to **public**") } - // @todo [SECURITY] check if user can create update channel type to private - if in.Type == types.ChannelTypePrivate && false { - return errors.New("Not allowed to change type of this channel to **private**") + if in.Type == types.ChannelTypePrivate && !svc.prm.CanCreatePrivateChannel() { + return errors.New("not allowed to change type of this channel to **private**") } - // @todo [SECURITY] check if user can update channel type to group - if in.Type == types.ChannelTypeGroup && false { - return errors.New("Not allowed to change type of this channel to **group**") + if in.Type == types.ChannelTypeGroup && !svc.prm.CanCreateGroupChannel() { + return errors.New("not allowed to change type of this channel to **group**") } changed = true @@ -368,11 +365,8 @@ func (svc *channel) Update(in *types.Channel) (ch *types.Channel, err error) { var chUpdatorId = repository.Identity(svc.ctx) if len(in.Name) > 0 && ch.Name != in.Name { - // @todo [SECURITY] can we change channel's name? - if false { - return errors.New("Not allowed to rename channel") - } else if settingsChannelNameLength > 0 && len(in.Name) > settingsChannelNameLength { - return errors.Errorf("Channel name (%d characters) too long (max: %d)", len(in.Name), settingsChannelNameLength) + if settingsChannelNameLength > 0 && len(in.Name) > settingsChannelNameLength { + return errors.Errorf("channel name (%d characters) too long (max: %d)", len(in.Name), settingsChannelNameLength) } else if ch.Name != "" { svc.scheduleSystemMessage(in, "<@%d> renamed channel **%s** (was: %s)", chUpdatorId, in.Name, ch.Name) } else { @@ -384,11 +378,8 @@ func (svc *channel) Update(in *types.Channel) (ch *types.Channel, err error) { } if len(in.Topic) > 0 && ch.Topic != in.Topic { - // @todo [SECURITY] can we change channel's topic? - if false { - return errors.New("Not allowed to change channel topic") - } else if settingsChannelTopicLength > 0 && len(in.Topic) > settingsChannelTopicLength { - return errors.Errorf("Channel topic (%d characters) too long (max: %d)", len(in.Topic), settingsChannelTopicLength) + if settingsChannelTopicLength > 0 && len(in.Topic) > settingsChannelTopicLength { + return errors.Errorf("channel topic (%d characters) too long (max: %d)", len(in.Topic), settingsChannelTopicLength) } else if ch.Topic != "" { svc.scheduleSystemMessage(in, "<@%d> changed channel topic: %s (was: %s)", chUpdatorId, in.Topic, ch.Topic) } else { @@ -421,12 +412,12 @@ func (svc *channel) Delete(ID uint64) (ch *types.Channel, err error) { return } - if !ch.CanDelete { - return errors.New("Not allowed to delete this channel") + if !svc.prm.CanDeleteChannel(ch) { + return errors.New("not allowed to delete this channel") } if ch.DeletedAt != nil { - return errors.New("Channel already deleted") + return errors.New("channel already deleted") } else { now := time.Now() ch.DeletedAt = &now @@ -455,12 +446,12 @@ func (svc *channel) Undelete(ID uint64) (ch *types.Channel, err error) { return } - if !ch.CanDelete { - return errors.New("Not allowed to undelete this channel") + if !svc.prm.CanUndeleteChannel(ch) { + return errors.New("not allowed to undelete this channel") } if ch.DeletedAt == nil { - return errors.New("Channel not deleted") + return errors.New("channel not deleted") } svc.scheduleSystemMessage(ch, "<@%d> undeleted this channel", userID) @@ -514,12 +505,12 @@ func (svc *channel) Archive(ID uint64) (ch *types.Channel, err error) { return } - if !ch.CanArchive { - return errors.New("Not allowed to archive this channel") + if !svc.prm.CanArchiveChannel(ch) { + return errors.New("not allowed to archive this channel") } if ch.ArchivedAt != nil { - return errors.New("Channel already archived") + return errors.New("channel already archived") } svc.scheduleSystemMessage(ch, "<@%d> archived this channel", userID) @@ -544,12 +535,12 @@ func (svc *channel) Unarchive(ID uint64) (ch *types.Channel, err error) { return } - if !ch.CanArchive { - return errors.New("Not allowed to unarchive this channel") + if !svc.prm.CanUnarchiveChannel(ch) { + return errors.New("not allowed to unarchive this channel") } if ch.ArchivedAt == nil { - return errors.New("Channel not archived") + return errors.New("channel not archived") } if err = svc.channel.UnarchiveChannelByID(ID); err != nil { @@ -580,11 +571,11 @@ func (svc *channel) InviteUser(channelID uint64, memberIDs ...uint64) (out types } if ch.Type == types.ChannelTypeGroup { - return nil, errors.New("Adding members to a group is not currently supported") + return nil, errors.New("adding members to a group is not currently supported") } - if !ch.CanChangeMembers { - return nil, errors.New("Not allowed to invite members") + if !svc.prm.CanManageChannelMembers(ch) { + return nil, errors.New("not allowed to invite members") } return out, svc.db.Transaction(func() (err error) { @@ -600,7 +591,7 @@ func (svc *channel) InviteUser(channelID uint64, memberIDs ...uint64) (out types for _, memberID := range memberIDs { user := users.FindByID(memberID) if user == nil { - return errors.New("Unexisting user") + return errors.New("unexisting user") } if e := existing.FindByUserID(memberID); e != nil { @@ -643,7 +634,7 @@ func (svc *channel) AddMember(channelID uint64, memberIDs ...uint64) (out types. } if ch.Type == types.ChannelTypeGroup { - return nil, errors.New("Adding members to a group is not currently supported") + return nil, errors.New("adding members to a group is not currently supported") } return out, svc.db.Transaction(func() (err error) { @@ -661,7 +652,7 @@ func (svc *channel) AddMember(channelID uint64, memberIDs ...uint64) (out types. user := users.FindByID(memberID) if user == nil { - return errors.New("Unexisting user") + return errors.New("unexisting user") } if e := existing.FindByUserID(memberID); e != nil { @@ -674,9 +665,8 @@ func (svc *channel) AddMember(channelID uint64, memberIDs ...uint64) (out types. } } - // @todo [SECURITY] implement proper checking - if !(ch.CanChangeMembers || memberID == userID && ch.Type == types.ChannelTypePublic) { - return errors.New("Not allowed to add members") + if !(svc.prm.CanManageChannelMembers(ch) || memberID == userID && ch.Type == types.ChannelTypePublic) { + return errors.New("not allowed to add members") } if !exists { @@ -744,9 +734,8 @@ func (svc *channel) DeleteMember(channelID uint64, memberIDs ...uint64) (err err continue } - // @todo [SECURITY] implement proper checking - if !(ch.CanChangeMembers || memberID == userID && ch.Type == types.ChannelTypePublic) { - return errors.New("Not allowed to add members") + if !(svc.prm.CanManageChannelMembers(ch) || memberID == userID && ch.Type == types.ChannelTypePublic) { + return errors.New("not allowed to add members") } if userID == memberID { @@ -825,24 +814,22 @@ func (svc *channel) sendChannelEvent(ch *types.Channel) (err error) { } func (svc *channel) setPermissionFlags(ch *types.Channel) (err error) { - var userID = repository.Identity(svc.ctx) + ch.CanJoin = svc.prm.CanJoinChannel(ch) + ch.CanPart = svc.prm.CanLeaveChannel(ch) + ch.CanObserve = svc.prm.CanReadChannel(ch) + ch.CanSendMessages = svc.prm.CanSendMessage(ch) - var ( - isMember = ch.Member != nil - isCreator = ch.CreatorID == userID - isOwner = isCreator || (isMember && ch.Member.Type == types.ChannelMembershipTypeOwner) - isPublic = ch.Type == types.ChannelTypePublic - ) + ch.CanDeleteMessages = svc.prm.CanDeleteMessages(ch) + ch.CanDeleteOwnMessages = svc.prm.CanDeleteOwnMessages(ch) + ch.CanUpdateMessages = svc.prm.CanUpdateMessages(ch) + ch.CanUpdateOwnMessages = svc.prm.CanUpdateOwnMessages(ch) + ch.CanChangeMembers = svc.prm.CanManageChannelMembers(ch) - ch.CanJoin = (ch.IsValid() && isPublic) || isOwner - ch.CanPart = isMember && ch.Type != types.ChannelTypeGroup - ch.CanObserve = (ch.IsValid() && isPublic) || isMember - ch.CanSendMessages = ch.CanObserve && isMember - ch.CanDeleteMessages = isOwner - ch.CanChangeMembers = isOwner - ch.CanUpdate = isOwner - ch.CanArchive = isOwner - ch.CanDelete = isOwner + ch.CanUpdate = svc.prm.CanUpdateChannel(ch) + ch.CanArchive = svc.prm.CanArchiveChannel(ch) + ch.CanUnarchive = svc.prm.CanUnarchiveChannel(ch) + ch.CanDelete = svc.prm.CanDeleteChannel(ch) + ch.CanUndelete = svc.prm.CanUndeleteChannel(ch) return nil } diff --git a/messaging/service/main_test.go b/messaging/service/main_test.go index 7986b4422..b1a52cc87 100644 --- a/messaging/service/main_test.go +++ b/messaging/service/main_test.go @@ -12,6 +12,7 @@ import ( "github.com/titpetric/factory" systemMigrate "github.com/crusttech/crust/system/db" + systemService "github.com/crusttech/crust/system/service" ) type mockDB struct{} @@ -53,6 +54,8 @@ func TestMain(m *testing.M) { } } + systemService.Init() + os.Exit(m.Run()) } diff --git a/messaging/service/permissions.go b/messaging/service/permissions.go index 28d273c17..77a2ced45 100644 --- a/messaging/service/permissions.go +++ b/messaging/service/permissions.go @@ -3,6 +3,7 @@ package service import ( "context" + "github.com/crusttech/crust/internal/auth" internalRules "github.com/crusttech/crust/internal/rules" "github.com/crusttech/crust/messaging/repository" "github.com/crusttech/crust/messaging/types" @@ -24,16 +25,20 @@ type ( CanGrantMessaging() bool CanCreatePublicChannel() bool CanCreatePrivateChannel() bool - CanCreateDirectChannel() bool + CanCreateGroupChannel() bool - CanUpdate(ch *types.Channel) bool - CanRead(ch *types.Channel) bool - CanJoin(ch *types.Channel) bool - CanLeave(ch *types.Channel) bool + CanUpdateChannel(ch *types.Channel) bool + CanReadChannel(ch *types.Channel) bool + CanJoinChannel(ch *types.Channel) bool + CanLeaveChannel(ch *types.Channel) bool + CanDeleteChannel(ch *types.Channel) bool + CanUndeleteChannel(ch *types.Channel) bool + CanArchiveChannel(ch *types.Channel) bool + CanUnarchiveChannel(ch *types.Channel) bool - CanManageMembers(ch *types.Channel) bool - CanManageWebhooks(ch *types.Channel) bool - CanManageAttachments(ch *types.Channel) bool + CanManageChannelMembers(ch *types.Channel) bool + CanManageChannelWebhooks(ch *types.Channel) bool + CanManageChannelAttachments(ch *types.Channel) bool CanSendMessage(ch *types.Channel) bool CanReplyMessage(ch *types.Channel) bool @@ -41,6 +46,8 @@ type ( CanAttachMessage(ch *types.Channel) bool CanUpdateOwnMessages(ch *types.Channel) bool CanUpdateMessages(ch *types.Channel) bool + CanDeleteOwnMessages(ch *types.Channel) bool + CanDeleteMessages(ch *types.Channel) bool CanReactMessage(ch *types.Channel) bool } ) @@ -77,40 +84,56 @@ func (p *permissions) CanCreatePrivateChannel() bool { return p.checkAccess("messaging", "channel.private.create") } -func (p *permissions) CanCreateDirectChannel() bool { - return p.checkAccess("messaging", "channel.direct.create") +func (p *permissions) CanCreateGroupChannel() bool { + return p.checkAccess("messaging", "channel.group.create") } -func (p *permissions) CanUpdate(ch *types.Channel) bool { - return p.checkAccess(ch.Resource().String(), "update") +func (p *permissions) CanUpdateChannel(ch *types.Channel) bool { + return p.checkAccess(ch.Resource().String(), "update", p.isChannelOwnerFallback(ch)) } -func (p *permissions) CanRead(ch *types.Channel) bool { - return p.checkAccess(ch.Resource().String(), "read") +func (p *permissions) CanReadChannel(ch *types.Channel) bool { + return p.checkAccess(ch.Resource().String(), "read", p.canReadFallback(ch)) } -func (p *permissions) CanJoin(ch *types.Channel) bool { - return p.checkAccess(ch.Resource().String(), "join") +func (p *permissions) CanJoinChannel(ch *types.Channel) bool { + return p.checkAccess(ch.Resource().String(), "join", p.canJoinFallback(ch)) } -func (p *permissions) CanLeave(ch *types.Channel) bool { - return p.checkAccess(ch.Resource().String(), "leave") +func (p *permissions) CanLeaveChannel(ch *types.Channel) bool { + return p.checkAccess(ch.Resource().String(), "leave", p.canLeaveFallback(ch)) } -func (p *permissions) CanManageMembers(ch *types.Channel) bool { - return p.checkAccess(ch.Resource().String(), "members.manage") +func (p *permissions) CanArchiveChannel(ch *types.Channel) bool { + return p.checkAccess(ch.Resource().String(), "archive", p.isChannelOwnerFallback(ch)) } -func (p *permissions) CanManageWebhooks(ch *types.Channel) bool { +func (p *permissions) CanUnarchiveChannel(ch *types.Channel) bool { + return p.checkAccess(ch.Resource().String(), "unarchive", p.isChannelOwnerFallback(ch)) +} + +func (p *permissions) CanDeleteChannel(ch *types.Channel) bool { + return p.checkAccess(ch.Resource().String(), "delete", p.isChannelOwnerFallback(ch)) +} + +func (p *permissions) CanUndeleteChannel(ch *types.Channel) bool { + return p.checkAccess(ch.Resource().String(), "undelete", p.isChannelOwnerFallback(ch)) +} + +func (p *permissions) CanManageChannelMembers(ch *types.Channel) bool { + return p.checkAccess(ch.Resource().String(), "members.manage", p.isChannelOwnerFallback(ch)) +} + +func (p *permissions) CanManageChannelWebhooks(ch *types.Channel) bool { return p.checkAccess(ch.Resource().String(), "webhooks.manage") } -func (p *permissions) CanManageAttachments(ch *types.Channel) bool { +func (p *permissions) CanManageChannelAttachments(ch *types.Channel) bool { return p.checkAccess(ch.Resource().String(), "attachments.manage") } func (p *permissions) CanSendMessage(ch *types.Channel) bool { - return p.checkAccess(ch.Resource().String(), "message.send") + return p.checkAccess(ch.Resource().String(), "message.send", p.canSendMessagesFallback(ch)) } func (p *permissions) CanReplyMessage(ch *types.Channel) bool { @@ -126,17 +149,83 @@ func (p *permissions) CanAttachMessage(ch *types.Channel) bool { } func (p *permissions) CanUpdateOwnMessages(ch *types.Channel) bool { - return p.checkAccess(ch.Resource().String(), "message.update.own") + return p.checkAccess(ch.Resource().String(), "message.update.own", p.isChannelOwnerFallback(ch)) } func (p *permissions) CanUpdateMessages(ch *types.Channel) bool { - return p.checkAccess(ch.Resource().String(), "message.update.all") + return p.checkAccess(ch.Resource().String(), "message.update.all", p.isChannelOwnerFallback(ch)) +} + +func (p *permissions) CanDeleteOwnMessages(ch *types.Channel) bool { + return p.checkAccess(ch.Resource().String(), "message.delete.own", p.isChannelOwnerFallback(ch)) +} + +func (p *permissions) CanDeleteMessages(ch *types.Channel) bool { + return p.checkAccess(ch.Resource().String(), "message.delete.all", p.isChannelOwnerFallback(ch)) } func (p *permissions) CanReactMessage(ch *types.Channel) bool { return p.checkAccess(ch.Resource().String(), "message.react") } +func (p permissions) canJoinFallback(ch *types.Channel) func() internalRules.Access { + return func() internalRules.Access { + userID := auth.GetIdentityFromContext(p.ctx).Identity() + + isMember := ch.Member != nil + isCreator := ch.CreatorID == userID + isOwner := isCreator || (isMember && ch.Member.Type == types.ChannelMembershipTypeOwner) + isPublic := ch.Type == types.ChannelTypePublic + + if (ch.IsValid() && isPublic) || isOwner { + return internalRules.Allow + } + return internalRules.Deny + } +} + +func (p permissions) canReadFallback(ch *types.Channel) func() internalRules.Access { + return func() internalRules.Access { + if (ch.IsValid() && ch.Type == types.ChannelTypePublic) || ch.Member != nil { + return internalRules.Allow + } + return internalRules.Deny + } +} + +func (p permissions) canSendMessagesFallback(ch *types.Channel) func() internalRules.Access { + return func() internalRules.Access { + if ch.IsValid() && ch.Type == types.ChannelTypePublic && ch.Member != nil { + return internalRules.Allow + } + return internalRules.Deny + } +} + +func (p permissions) canLeaveFallback(ch *types.Channel) func() internalRules.Access { + return func() internalRules.Access { + if ch.Member != nil && ch.Type != types.ChannelTypeGroup { + return internalRules.Allow + } + return internalRules.Deny + } +} + +func (p permissions) isChannelOwnerFallback(ch *types.Channel) func() internalRules.Access { + return func() internalRules.Access { + userID := auth.GetIdentityFromContext(p.ctx).Identity() + + isMember := ch.Member != nil + isCreator := ch.CreatorID == userID + isOwner := isCreator || (isMember && ch.Member.Type == types.ChannelMembershipTypeOwner) + + if isOwner { + return internalRules.Allow + } + return internalRules.Deny + } +} + func (p *permissions) checkAccess(resource string, operation string, fallbacks ...internalRules.CheckAccessFunc) bool { access := p.rules.Check(resource, operation, fallbacks...) if access == internalRules.Allow { diff --git a/messaging/service/permissions_test.go b/messaging/service/permissions_test.go index 497ea03cc..1f51d65f1 100644 --- a/messaging/service/permissions_test.go +++ b/messaging/service/permissions_test.go @@ -4,6 +4,9 @@ import ( "context" "testing" + "github.com/pkg/errors" + "github.com/titpetric/factory" + "github.com/crusttech/crust/internal/auth" "github.com/crusttech/crust/internal/rules" . "github.com/crusttech/crust/internal/test" @@ -15,121 +18,126 @@ import ( ) func TestPermissions(t *testing.T) { - ctx := context.TODO() + // Create test user and role. + user := &systemTypes.User{ID: 1337} + role := &systemTypes.Role{ID: 1234567, Name: "Admins"} - // Create user with role and add it to context. - userSvc := systemService.User().With(ctx) - user := &systemTypes.User{ - Name: "John Doe", - Username: "johndoe", - SatosaID: "1234", - } - err := user.GeneratePassword("johndoe") - NoError(t, err, "expected no error generating password, got %v", err) + // Write user to context. + ctx := auth.SetIdentityToContext(context.Background(), user) - _, err = userSvc.Create(user) - NoError(t, err, "expected no error creating user, got %v", err) + // Connect do DB. + db := factory.Database.MustGet() - roleSvc := systemService.Role().With(ctx) - role := &systemTypes.Role{ - Name: "Test role v1", - } - role, err = roleSvc.Create(role) - NoError(t, err, "expected no error creating role, got %v", err) + // Run test with savepoint. + err := func() error { + db.Exec("SAVEPOINT permissions_test") - err = roleSvc.MemberAdd(role.ID, user.ID) - NoError(t, err, "expected no error adding user to role, got %v", err) + db.Insert("sys_user", user) + db.Insert("sys_role_member", systemTypes.RoleMember{RoleID: role.ID, UserID: user.ID}) - // Set Identity. - ctx = auth.SetIdentityToContext(ctx, user) + // Insert `grant` permission for `messaging`. + { + db := repository.DB(ctx) + resources := rules.NewResources(ctx, db) - // Insert `grant` permission for `messaging`. - { - db := repository.DB(ctx) - resources := rules.NewResources(ctx, db) + list := []rules.Rule{ + rules.Rule{Resource: "messaging", Operation: "grant", Value: rules.Allow}, + } - list := []rules.Rule{ - rules.Rule{Resource: "messaging", Operation: "grant", Value: rules.Allow}, + err := resources.Grant(role.ID, list) + NoError(t, err, "expected no error, got %v", err) } - err := resources.Grant(role.ID, list) - NoError(t, err, "expected no error, got %v", err) - } + // Generate services. + channelSvc := (&channel{ + usr: systemService.User(), + evl: Event(), + prm: Permissions(), + }).With(ctx) - // Generate services. - channelSvc := (&channel{ - usr: systemService.User(), - evl: Event(), - prm: Permissions(), - }).With(ctx) + permissionsSvc := Permissions().With(ctx) + systemRulesSvc := systemService.Rules().With(ctx) - permissionsSvc := Permissions().With(ctx) - systemRulesSvc := systemService.Rules().With(ctx) + // Remove `access` to messaging service. + { + list := []rules.Rule{ + rules.Rule{Resource: "messaging", Operation: "access", Value: rules.Deny}, + } + _, err := systemRulesSvc.Update(role.ID, list) + NoError(t, err, "expected no error, got %v", err) - // Test `access` to messaging service. - ret := permissionsSvc.CanAccessMessaging() - Assert(t, ret == false, "expected CanAccessMessaging == false, got %v", ret) - - // Add `access` to messaging service. - list := []rules.Rule{ - rules.Rule{Resource: "messaging", Operation: "access", Value: rules.Allow}, - } - _, err = systemRulesSvc.Update(role.ID, list) - NoError(t, err, "expected no error, got %v", err) - - // Test `access` to messaging service. - ret = permissionsSvc.CanAccessMessaging() - Assert(t, ret == true, "expected CanAccessMessaging == true, got %v", ret) - - // Create test channel. - ch := &types.Channel{ - Name: "TestChan", - Topic: "No topic", - } - ch, err = channelSvc.Create(ch) - NoError(t, err, "expected no error, got %v", err) - - // @Todo: add permission for create channel and test it. - - // Test CanRead permissions. [1 - no permission, 2 - allow] - { - ret = permissionsSvc.CanRead(ch) - Assert(t, ret == false, "expected CanRead == false, got %v") - - // Add [messaging:channel:*, read, allow] - list = []rules.Rule{ - rules.Rule{Resource: "messaging:channel:*", Operation: "read", Value: rules.Allow}, + // Test `access` to messaging service. + ret := permissionsSvc.CanAccessMessaging() + Assert(t, ret == false, "expected CanAccessMessaging == false, got %v", ret) } - _, err = systemRulesSvc.Update(role.ID, list) - NoError(t, err, "expected no error, got %v", err) - ret = permissionsSvc.CanRead(ch) - Assert(t, ret == true, "expected CanRead == true, got %v") - } + // Add `access` to messaging service. + { + list := []rules.Rule{ + rules.Rule{Resource: "messaging", Operation: "access", Value: rules.Allow}, + } + _, err := systemRulesSvc.Update(role.ID, list) + NoError(t, err, "expected no error, got %v", err) - // Test CanJoin permissions [1 - deny, 2 - allow for resourceID] - { - // Add [messaging:channel:*, join, deny] - list = []rules.Rule{ - rules.Rule{Resource: "messaging:channel:*", Operation: "join", Value: rules.Deny}, + // Test `access` to messaging service. + ret := permissionsSvc.CanAccessMessaging() + Assert(t, ret == true, "expected CanAccessMessaging == true, got %v", ret) } - _, err = systemRulesSvc.Update(role.ID, list) - NoError(t, err, "expected no error, got %v", err) - ret = permissionsSvc.CanJoin(ch) - Assert(t, ret == false, "expected CanJoin == false, got %v") - - // Add [messaging:channel:ID, join, allow] - list = []rules.Rule{ - rules.Rule{Resource: ch.Resource().String(), Operation: "join", Value: rules.Allow}, + // Create test channel. + ch := &types.Channel{ + Name: "TestChan", + Topic: "No topic", } - _, err = systemRulesSvc.Update(role.ID, list) + ch, err := channelSvc.Create(ch) NoError(t, err, "expected no error, got %v", err) - ret = permissionsSvc.CanJoin(ch) - Assert(t, ret == true, "expected CanJoin == true, got %v") - } + // @Todo: add permission for create channel and test it. - // Remove test channel. - channelSvc.Delete(ch.ID) + // Test CanReadChannel permissions. [1 - allow, 2 no permission] + { + ret := permissionsSvc.CanReadChannel(ch) + Assert(t, ret == true, "expected CanReadChannel == true, got %v", ret) + + // Add [messaging:channel:*, read, deny] + list := []rules.Rule{ + rules.Rule{Resource: "messaging:channel:*", Operation: "read", Value: rules.Deny}, + } + _, err = systemRulesSvc.Update(role.ID, list) + NoError(t, err, "expected no error, got %v", err) + + ret = permissionsSvc.CanReadChannel(ch) + Assert(t, ret == false, "expected CanReadChannel == false, got %v", ret) + } + + // Test CanJoinChannel permissions [1 - deny, 2 - allow for resourceID] + { + // Add [messaging:channel:*, join, deny] + list := []rules.Rule{ + rules.Rule{Resource: "messaging:channel:*", Operation: "join", Value: rules.Deny}, + } + _, err = systemRulesSvc.Update(role.ID, list) + NoError(t, err, "expected no error, got %v", err) + + ret := permissionsSvc.CanJoinChannel(ch) + Assert(t, ret == false, "expected CanJoinChannel == false, got %v") + + // Add [messaging:channel:ID, join, allow] + list = []rules.Rule{ + rules.Rule{Resource: ch.Resource().String(), Operation: "join", Value: rules.Allow}, + } + _, err = systemRulesSvc.Update(role.ID, list) + NoError(t, err, "expected no error, got %v", err) + + ret = permissionsSvc.CanJoinChannel(ch) + Assert(t, ret == true, "expected CanJoinChannel == true, got %v") + } + + // Remove test channel. + channelSvc.Delete(ch.ID) + return errors.New("Rollback") + }() + if err != nil { + db.Exec("ROLLBACK TO SAVEPOINT permissions_test") + } } diff --git a/messaging/types/channel.go b/messaging/types/channel.go index b015b333d..4da87c211 100644 --- a/messaging/types/channel.go +++ b/messaging/types/channel.go @@ -27,15 +27,20 @@ type ( LastMessageID uint64 `json:",omitempty" db:"rel_last_message"` - CanJoin bool `json:"-" db:"-"` - CanPart bool `json:"-" db:"-"` - CanObserve bool `json:"-" db:"-"` - CanSendMessages bool `json:"-" db:"-"` - CanDeleteMessages bool `json:"-" db:"-"` - CanChangeMembers bool `json:"-" db:"-"` - CanUpdate bool `json:"-" db:"-"` - CanArchive bool `json:"-" db:"-"` - CanDelete bool `json:"-" db:"-"` + CanJoin bool `json:"-" db:"-"` + CanPart bool `json:"-" db:"-"` + CanObserve bool `json:"-" db:"-"` + CanSendMessages bool `json:"-" db:"-"` + CanDeleteMessages bool `json:"-" db:"-"` + CanDeleteOwnMessages bool `json:"-" db:"-"` + CanUpdateMessages bool `json:"-" db:"-"` + CanUpdateOwnMessages bool `json:"-" db:"-"` + CanChangeMembers bool `json:"-" db:"-"` + CanUpdate bool `json:"-" db:"-"` + CanArchive bool `json:"-" db:"-"` + CanUnarchive bool `json:"-" db:"-"` + CanDelete bool `json:"-" db:"-"` + CanUndelete bool `json:"-" db:"-"` Member *ChannelMember `json:"-" db:"-"` Members []uint64 `json:"-" db:"-"` diff --git a/system/service/validation.go b/system/service/validation.go index eca255001..e0731174d 100644 --- a/system/service/validation.go +++ b/system/service/validation.go @@ -33,13 +33,17 @@ var ( "grant": true, "channel.public.create": true, "channel.private.create": true, - "channel.direct.create": true, + "channel.group.create": true, }, "messaging:channel": map[string]bool{ "update": true, "read": true, "join": true, "leave": true, + "delete": true, + "undelete": true, + "archive": true, + "unarchive": true, "members.manage": true, "webhooks.manage": true, "attachments.manage": true, @@ -49,6 +53,8 @@ var ( "message.attach": true, "message.update.own": true, "message.update.all": true, + "message.delete.own": true, + "message.delete.all": true, "message.react": true, }, "compose": map[string]bool{ From 601599abaeaaec032691f3fb4931b7ba7a4a89ed Mon Sep 17 00:00:00 2001 From: Mitja Zivkovic Date: Tue, 12 Mar 2019 01:00:34 +0100 Subject: [PATCH 2/2] upd(messaging): message permission checks --- messaging/service/message.go | 79 +++++++++++++++++--------------- messaging/service/permissions.go | 5 ++ 2 files changed, 48 insertions(+), 36 deletions(-) diff --git a/messaging/service/message.go b/messaging/service/message.go index f1ee3c3f0..fe530be4a 100644 --- a/messaging/service/message.go +++ b/messaging/service/message.go @@ -7,6 +7,7 @@ import ( "github.com/pkg/errors" + "github.com/crusttech/crust/internal/auth" "github.com/crusttech/crust/internal/payload" "github.com/crusttech/crust/messaging/repository" "github.com/crusttech/crust/messaging/types" @@ -29,6 +30,7 @@ type ( usr systemService.UserService evl EventService + prm PermissionsService } MessageService interface { @@ -68,6 +70,7 @@ func Message() MessageService { return &message{ usr: systemService.DefaultUser, evl: DefaultEvent, + prm: DefaultPermissions, } } @@ -79,6 +82,7 @@ func (svc *message) With(ctx context.Context) MessageService { usr: svc.usr.With(ctx), evl: svc.evl.With(ctx), + prm: svc.prm.With(ctx), attachment: repository.Attachment(ctx, db), channel: repository.Channel(ctx, db), @@ -91,11 +95,11 @@ func (svc *message) With(ctx context.Context) MessageService { } func (svc *message) Find(filter *types.MessageFilter) (mm types.MessageSet, err error) { - // @todo get user from context - filter.CurrentUserID = repository.Identity(svc.ctx) + filter.CurrentUserID = auth.GetIdentityFromContext(svc.ctx).Identity() - // @todo verify if current user can access & read from this channel - _ = filter.ChannelID + if !svc.prm.CanReadChannelByID(filter.ChannelID) { + return nil, errors.New("not allowed to access channel") + } mm, err = svc.message.FindMessages(filter) if err != nil { @@ -106,11 +110,11 @@ func (svc *message) Find(filter *types.MessageFilter) (mm types.MessageSet, err } func (svc *message) FindThreads(filter *types.MessageFilter) (mm types.MessageSet, err error) { - // @todo get user from context - filter.CurrentUserID = repository.Identity(svc.ctx) + filter.CurrentUserID = auth.GetIdentityFromContext(svc.ctx).Identity() - // @todo verify if current user can access & read from this channel - _ = filter.ChannelID + if !svc.prm.CanReadChannelByID(filter.ChannelID) { + return nil, errors.New("not allowed to access channel") + } mm, err = svc.message.FindThreads(filter) if err != nil { @@ -129,13 +133,12 @@ func (svc *message) Create(in *types.Message) (message *types.Message, err error var mlen = len(in.Message) if mlen == 0 { - return nil, errors.Errorf("Refusing to create message without contents") + return nil, errors.Errorf("refusing to create message without contents") } else if settingsMessageBodyLength > 0 && mlen > settingsMessageBodyLength { - return nil, errors.Errorf("Message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength) + return nil, errors.Errorf("message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength) } - // @todo get user from context - var currentUserID uint64 = repository.Identity(svc.ctx) + var currentUserID = auth.GetIdentityFromContext(svc.ctx).Identity() in.UserID = currentUserID @@ -158,7 +161,7 @@ func (svc *message) Create(in *types.Message) (message *types.Message, err error } if !original.Type.IsRepliable() { - return errors.Errorf("Unable to reply on this message (type = %s)", original.Type) + return errors.Errorf("unable to reply on this message (type = %s)", original.Type) } // We do not want to have multi-level threads @@ -167,7 +170,7 @@ func (svc *message) Create(in *types.Message) (message *types.Message, err error in.ChannelID = original.ChannelID - // Increment counter, on struct and in repostiry. + // Increment counter, on struct and in repository. original.Replies++ if err = svc.message.IncReplyCount(original.ID); err != nil { return @@ -178,10 +181,12 @@ func (svc *message) Create(in *types.Message) (message *types.Message, err error } if in.ChannelID == 0 { - return errors.New("ChannelID missing") + return errors.New("channelID missing") } - // @todo [SECURITY] verify if current user can access & write to this channel + if !svc.prm.CanReadChannelByID(in.ChannelID) { + return errors.New("not allowed to access channel") + } if message, err = svc.message.CreateMessage(in); err != nil { return @@ -208,21 +213,21 @@ func (svc *message) Update(in *types.Message) (message *types.Message, err error var mlen = len(in.Message) if mlen == 0 { - return nil, errors.Errorf("Refusing to update message without contents") + return nil, errors.Errorf("refusing to update message without contents") } else if settingsMessageBodyLength > 0 && mlen > settingsMessageBodyLength { - return nil, errors.Errorf("Message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength) + return nil, errors.Errorf("message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength) } - // @todo get user from context - var currentUserID uint64 = repository.Identity(svc.ctx) + var currentUserID = auth.GetIdentityFromContext(svc.ctx).Identity() - // @todo verify if current user can access & write to this channel - _ = currentUserID + if !svc.prm.CanReadChannelByID(in.ChannelID) { + return nil, errors.New("not allowed to access channel") + } return message, svc.db.Transaction(func() (err error) { message, err = svc.message.FindMessageByID(in.ID) if err != nil { - return errors.Wrap(err, "Could not load message for editing") + return errors.Wrap(err, "could not load message for editing") } if message.Message == in.Message { @@ -231,7 +236,7 @@ func (svc *message) Update(in *types.Message) (message *types.Message, err error } if message.UserID != currentUserID { - return errors.New("Not an owner") + return errors.New("not an owner") } // Allow message content to be changed @@ -250,10 +255,8 @@ func (svc *message) Update(in *types.Message) (message *types.Message, err error } func (svc *message) Delete(ID uint64) error { - // @todo get user from context - var currentUserID uint64 = repository.Identity(svc.ctx) + var currentUserID = auth.GetIdentityFromContext(svc.ctx).Identity() - // @todo verify if current user can access & write to this channel _ = currentUserID // @todo load current message @@ -269,6 +272,10 @@ func (svc *message) Delete(ID uint64) error { return err } + if !svc.prm.CanReadChannelByID(deletedMsg.ChannelID) { + return errors.New("not allowed to access channel") + } + if deletedMsg.ReplyTo > 0 { original, err = svc.message.FindMessageByID(deletedMsg.ReplyTo) if err != nil { @@ -386,10 +393,8 @@ func (svc *message) RemoveBookmark(messageID uint64) error { // React on a message with an emoji func (svc *message) flag(messageID uint64, flag string, remove bool) error { - // @todo get user from context - var currentUserID uint64 = repository.Identity(svc.ctx) + var currentUserID = auth.GetIdentityFromContext(svc.ctx).Identity() - // @todo verify if current user can access & write to this channel _ = currentUserID if strings.TrimSpace(flag) == "" { @@ -403,8 +408,6 @@ func (svc *message) flag(messageID uint64, flag string, remove bool) error { var flagOwnerId = currentUserID var f *types.MessageFlag - // @todo [SECURITY] verify if current user can access & write to this channel - if flag == types.MessageFlagPinnedToChannel { // It does not matter how is the owner of the pin, flagOwnerId = 0 @@ -429,6 +432,10 @@ func (svc *message) flag(messageID uint64, flag string, remove bool) error { return } + if !svc.prm.CanReadChannelByID(msg.ChannelID) { + return errors.New("not allowed to access channel") + } + if remove { err = svc.mflag.DeleteByID(f.ID) f.DeletedAt = timeNowPtr() @@ -450,7 +457,7 @@ func (svc *message) flag(messageID uint64, flag string, remove bool) error { return }) - return errors.Wrap(err, "Can not flag/un-flag message") + return errors.Wrap(err, "can not flag/un-flag message") } func (svc *message) preload(mm types.MessageSet) (err error) { @@ -620,7 +627,7 @@ func (svc *message) extractMentions(m *types.Message) (mm types.MentionSet) { func (svc *message) updateMentions(messageID uint64, mm types.MentionSet) error { if existing, err := svc.mentions.FindByMessageIDs(messageID); err != nil { - return errors.Wrap(err, "Could not update mentions") + return errors.Wrap(err, "could not update mentions") } else if len(mm) > 0 { add, _, del := existing.Diff(mm) @@ -630,7 +637,7 @@ func (svc *message) updateMentions(messageID uint64, mm types.MentionSet) error }) if err != nil { - return errors.Wrap(err, "Could not create mentions") + return errors.Wrap(err, "could not create mentions") } err = del.Walk(func(m *types.Mention) error { @@ -638,7 +645,7 @@ func (svc *message) updateMentions(messageID uint64, mm types.MentionSet) error }) if err != nil { - return errors.Wrap(err, "Could not delete mentions") + return errors.Wrap(err, "could not delete mentions") } } else { return svc.mentions.DeleteByMessageID(messageID) diff --git a/messaging/service/permissions.go b/messaging/service/permissions.go index 77a2ced45..85008c19d 100644 --- a/messaging/service/permissions.go +++ b/messaging/service/permissions.go @@ -29,6 +29,7 @@ type ( CanUpdateChannel(ch *types.Channel) bool CanReadChannel(ch *types.Channel) bool + CanReadChannelByID(id uint64) bool CanJoinChannel(ch *types.Channel) bool CanLeaveChannel(ch *types.Channel) bool CanDeleteChannel(ch *types.Channel) bool @@ -96,6 +97,10 @@ func (p *permissions) CanReadChannel(ch *types.Channel) bool { return p.checkAccess(ch.Resource().String(), "read", p.canReadFallback(ch)) } +func (p *permissions) CanReadChannelByID(id uint64) bool { + return p.CanReadChannel(&types.Channel{ID: id}) +} + func (p *permissions) CanJoinChannel(ch *types.Channel) bool { return p.checkAccess(ch.Resource().String(), "join", p.canJoinFallback(ch)) }