3
0

upd(messaging): message permission checks

This commit is contained in:
Mitja Zivkovic
2019-03-12 01:00:34 +01:00
parent 9f494a37f9
commit 601599abae
2 changed files with 48 additions and 36 deletions

View File

@@ -7,6 +7,7 @@ import (
"github.com/pkg/errors"
"github.com/crusttech/crust/internal/auth"
"github.com/crusttech/crust/internal/payload"
"github.com/crusttech/crust/messaging/repository"
"github.com/crusttech/crust/messaging/types"
@@ -29,6 +30,7 @@ type (
usr systemService.UserService
evl EventService
prm PermissionsService
}
MessageService interface {
@@ -68,6 +70,7 @@ func Message() MessageService {
return &message{
usr: systemService.DefaultUser,
evl: DefaultEvent,
prm: DefaultPermissions,
}
}
@@ -79,6 +82,7 @@ func (svc *message) With(ctx context.Context) MessageService {
usr: svc.usr.With(ctx),
evl: svc.evl.With(ctx),
prm: svc.prm.With(ctx),
attachment: repository.Attachment(ctx, db),
channel: repository.Channel(ctx, db),
@@ -91,11 +95,11 @@ func (svc *message) With(ctx context.Context) MessageService {
}
func (svc *message) Find(filter *types.MessageFilter) (mm types.MessageSet, err error) {
// @todo get user from context
filter.CurrentUserID = repository.Identity(svc.ctx)
filter.CurrentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
// @todo verify if current user can access & read from this channel
_ = filter.ChannelID
if !svc.prm.CanReadChannelByID(filter.ChannelID) {
return nil, errors.New("not allowed to access channel")
}
mm, err = svc.message.FindMessages(filter)
if err != nil {
@@ -106,11 +110,11 @@ func (svc *message) Find(filter *types.MessageFilter) (mm types.MessageSet, err
}
func (svc *message) FindThreads(filter *types.MessageFilter) (mm types.MessageSet, err error) {
// @todo get user from context
filter.CurrentUserID = repository.Identity(svc.ctx)
filter.CurrentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
// @todo verify if current user can access & read from this channel
_ = filter.ChannelID
if !svc.prm.CanReadChannelByID(filter.ChannelID) {
return nil, errors.New("not allowed to access channel")
}
mm, err = svc.message.FindThreads(filter)
if err != nil {
@@ -129,13 +133,12 @@ func (svc *message) Create(in *types.Message) (message *types.Message, err error
var mlen = len(in.Message)
if mlen == 0 {
return nil, errors.Errorf("Refusing to create message without contents")
return nil, errors.Errorf("refusing to create message without contents")
} else if settingsMessageBodyLength > 0 && mlen > settingsMessageBodyLength {
return nil, errors.Errorf("Message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength)
return nil, errors.Errorf("message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength)
}
// @todo get user from context
var currentUserID uint64 = repository.Identity(svc.ctx)
var currentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
in.UserID = currentUserID
@@ -158,7 +161,7 @@ func (svc *message) Create(in *types.Message) (message *types.Message, err error
}
if !original.Type.IsRepliable() {
return errors.Errorf("Unable to reply on this message (type = %s)", original.Type)
return errors.Errorf("unable to reply on this message (type = %s)", original.Type)
}
// We do not want to have multi-level threads
@@ -167,7 +170,7 @@ func (svc *message) Create(in *types.Message) (message *types.Message, err error
in.ChannelID = original.ChannelID
// Increment counter, on struct and in repostiry.
// Increment counter, on struct and in repository.
original.Replies++
if err = svc.message.IncReplyCount(original.ID); err != nil {
return
@@ -178,10 +181,12 @@ func (svc *message) Create(in *types.Message) (message *types.Message, err error
}
if in.ChannelID == 0 {
return errors.New("ChannelID missing")
return errors.New("channelID missing")
}
// @todo [SECURITY] verify if current user can access & write to this channel
if !svc.prm.CanReadChannelByID(in.ChannelID) {
return errors.New("not allowed to access channel")
}
if message, err = svc.message.CreateMessage(in); err != nil {
return
@@ -208,21 +213,21 @@ func (svc *message) Update(in *types.Message) (message *types.Message, err error
var mlen = len(in.Message)
if mlen == 0 {
return nil, errors.Errorf("Refusing to update message without contents")
return nil, errors.Errorf("refusing to update message without contents")
} else if settingsMessageBodyLength > 0 && mlen > settingsMessageBodyLength {
return nil, errors.Errorf("Message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength)
return nil, errors.Errorf("message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength)
}
// @todo get user from context
var currentUserID uint64 = repository.Identity(svc.ctx)
var currentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
// @todo verify if current user can access & write to this channel
_ = currentUserID
if !svc.prm.CanReadChannelByID(in.ChannelID) {
return nil, errors.New("not allowed to access channel")
}
return message, svc.db.Transaction(func() (err error) {
message, err = svc.message.FindMessageByID(in.ID)
if err != nil {
return errors.Wrap(err, "Could not load message for editing")
return errors.Wrap(err, "could not load message for editing")
}
if message.Message == in.Message {
@@ -231,7 +236,7 @@ func (svc *message) Update(in *types.Message) (message *types.Message, err error
}
if message.UserID != currentUserID {
return errors.New("Not an owner")
return errors.New("not an owner")
}
// Allow message content to be changed
@@ -250,10 +255,8 @@ func (svc *message) Update(in *types.Message) (message *types.Message, err error
}
func (svc *message) Delete(ID uint64) error {
// @todo get user from context
var currentUserID uint64 = repository.Identity(svc.ctx)
var currentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
// @todo verify if current user can access & write to this channel
_ = currentUserID
// @todo load current message
@@ -269,6 +272,10 @@ func (svc *message) Delete(ID uint64) error {
return err
}
if !svc.prm.CanReadChannelByID(deletedMsg.ChannelID) {
return errors.New("not allowed to access channel")
}
if deletedMsg.ReplyTo > 0 {
original, err = svc.message.FindMessageByID(deletedMsg.ReplyTo)
if err != nil {
@@ -386,10 +393,8 @@ func (svc *message) RemoveBookmark(messageID uint64) error {
// React on a message with an emoji
func (svc *message) flag(messageID uint64, flag string, remove bool) error {
// @todo get user from context
var currentUserID uint64 = repository.Identity(svc.ctx)
var currentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
// @todo verify if current user can access & write to this channel
_ = currentUserID
if strings.TrimSpace(flag) == "" {
@@ -403,8 +408,6 @@ func (svc *message) flag(messageID uint64, flag string, remove bool) error {
var flagOwnerId = currentUserID
var f *types.MessageFlag
// @todo [SECURITY] verify if current user can access & write to this channel
if flag == types.MessageFlagPinnedToChannel {
// It does not matter how is the owner of the pin,
flagOwnerId = 0
@@ -429,6 +432,10 @@ func (svc *message) flag(messageID uint64, flag string, remove bool) error {
return
}
if !svc.prm.CanReadChannelByID(msg.ChannelID) {
return errors.New("not allowed to access channel")
}
if remove {
err = svc.mflag.DeleteByID(f.ID)
f.DeletedAt = timeNowPtr()
@@ -450,7 +457,7 @@ func (svc *message) flag(messageID uint64, flag string, remove bool) error {
return
})
return errors.Wrap(err, "Can not flag/un-flag message")
return errors.Wrap(err, "can not flag/un-flag message")
}
func (svc *message) preload(mm types.MessageSet) (err error) {
@@ -620,7 +627,7 @@ func (svc *message) extractMentions(m *types.Message) (mm types.MentionSet) {
func (svc *message) updateMentions(messageID uint64, mm types.MentionSet) error {
if existing, err := svc.mentions.FindByMessageIDs(messageID); err != nil {
return errors.Wrap(err, "Could not update mentions")
return errors.Wrap(err, "could not update mentions")
} else if len(mm) > 0 {
add, _, del := existing.Diff(mm)
@@ -630,7 +637,7 @@ func (svc *message) updateMentions(messageID uint64, mm types.MentionSet) error
})
if err != nil {
return errors.Wrap(err, "Could not create mentions")
return errors.Wrap(err, "could not create mentions")
}
err = del.Walk(func(m *types.Mention) error {
@@ -638,7 +645,7 @@ func (svc *message) updateMentions(messageID uint64, mm types.MentionSet) error
})
if err != nil {
return errors.Wrap(err, "Could not delete mentions")
return errors.Wrap(err, "could not delete mentions")
}
} else {
return svc.mentions.DeleteByMessageID(messageID)

View File

@@ -29,6 +29,7 @@ type (
CanUpdateChannel(ch *types.Channel) bool
CanReadChannel(ch *types.Channel) bool
CanReadChannelByID(id uint64) bool
CanJoinChannel(ch *types.Channel) bool
CanLeaveChannel(ch *types.Channel) bool
CanDeleteChannel(ch *types.Channel) bool
@@ -96,6 +97,10 @@ func (p *permissions) CanReadChannel(ch *types.Channel) bool {
return p.checkAccess(ch.Resource().String(), "read", p.canReadFallback(ch))
}
func (p *permissions) CanReadChannelByID(id uint64) bool {
return p.CanReadChannel(&types.Channel{ID: id})
}
func (p *permissions) CanJoinChannel(ch *types.Channel) bool {
return p.checkAccess(ch.Resource().String(), "join", p.canJoinFallback(ch))
}