From 601599abaeaaec032691f3fb4931b7ba7a4a89ed Mon Sep 17 00:00:00 2001 From: Mitja Zivkovic Date: Tue, 12 Mar 2019 01:00:34 +0100 Subject: [PATCH] upd(messaging): message permission checks --- messaging/service/message.go | 79 +++++++++++++++++--------------- messaging/service/permissions.go | 5 ++ 2 files changed, 48 insertions(+), 36 deletions(-) diff --git a/messaging/service/message.go b/messaging/service/message.go index f1ee3c3f0..fe530be4a 100644 --- a/messaging/service/message.go +++ b/messaging/service/message.go @@ -7,6 +7,7 @@ import ( "github.com/pkg/errors" + "github.com/crusttech/crust/internal/auth" "github.com/crusttech/crust/internal/payload" "github.com/crusttech/crust/messaging/repository" "github.com/crusttech/crust/messaging/types" @@ -29,6 +30,7 @@ type ( usr systemService.UserService evl EventService + prm PermissionsService } MessageService interface { @@ -68,6 +70,7 @@ func Message() MessageService { return &message{ usr: systemService.DefaultUser, evl: DefaultEvent, + prm: DefaultPermissions, } } @@ -79,6 +82,7 @@ func (svc *message) With(ctx context.Context) MessageService { usr: svc.usr.With(ctx), evl: svc.evl.With(ctx), + prm: svc.prm.With(ctx), attachment: repository.Attachment(ctx, db), channel: repository.Channel(ctx, db), @@ -91,11 +95,11 @@ func (svc *message) With(ctx context.Context) MessageService { } func (svc *message) Find(filter *types.MessageFilter) (mm types.MessageSet, err error) { - // @todo get user from context - filter.CurrentUserID = repository.Identity(svc.ctx) + filter.CurrentUserID = auth.GetIdentityFromContext(svc.ctx).Identity() - // @todo verify if current user can access & read from this channel - _ = filter.ChannelID + if !svc.prm.CanReadChannelByID(filter.ChannelID) { + return nil, errors.New("not allowed to access channel") + } mm, err = svc.message.FindMessages(filter) if err != nil { @@ -106,11 +110,11 @@ func (svc *message) Find(filter *types.MessageFilter) (mm types.MessageSet, err } func (svc *message) FindThreads(filter *types.MessageFilter) (mm types.MessageSet, err error) { - // @todo get user from context - filter.CurrentUserID = repository.Identity(svc.ctx) + filter.CurrentUserID = auth.GetIdentityFromContext(svc.ctx).Identity() - // @todo verify if current user can access & read from this channel - _ = filter.ChannelID + if !svc.prm.CanReadChannelByID(filter.ChannelID) { + return nil, errors.New("not allowed to access channel") + } mm, err = svc.message.FindThreads(filter) if err != nil { @@ -129,13 +133,12 @@ func (svc *message) Create(in *types.Message) (message *types.Message, err error var mlen = len(in.Message) if mlen == 0 { - return nil, errors.Errorf("Refusing to create message without contents") + return nil, errors.Errorf("refusing to create message without contents") } else if settingsMessageBodyLength > 0 && mlen > settingsMessageBodyLength { - return nil, errors.Errorf("Message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength) + return nil, errors.Errorf("message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength) } - // @todo get user from context - var currentUserID uint64 = repository.Identity(svc.ctx) + var currentUserID = auth.GetIdentityFromContext(svc.ctx).Identity() in.UserID = currentUserID @@ -158,7 +161,7 @@ func (svc *message) Create(in *types.Message) (message *types.Message, err error } if !original.Type.IsRepliable() { - return errors.Errorf("Unable to reply on this message (type = %s)", original.Type) + return errors.Errorf("unable to reply on this message (type = %s)", original.Type) } // We do not want to have multi-level threads @@ -167,7 +170,7 @@ func (svc *message) Create(in *types.Message) (message *types.Message, err error in.ChannelID = original.ChannelID - // Increment counter, on struct and in repostiry. + // Increment counter, on struct and in repository. original.Replies++ if err = svc.message.IncReplyCount(original.ID); err != nil { return @@ -178,10 +181,12 @@ func (svc *message) Create(in *types.Message) (message *types.Message, err error } if in.ChannelID == 0 { - return errors.New("ChannelID missing") + return errors.New("channelID missing") } - // @todo [SECURITY] verify if current user can access & write to this channel + if !svc.prm.CanReadChannelByID(in.ChannelID) { + return errors.New("not allowed to access channel") + } if message, err = svc.message.CreateMessage(in); err != nil { return @@ -208,21 +213,21 @@ func (svc *message) Update(in *types.Message) (message *types.Message, err error var mlen = len(in.Message) if mlen == 0 { - return nil, errors.Errorf("Refusing to update message without contents") + return nil, errors.Errorf("refusing to update message without contents") } else if settingsMessageBodyLength > 0 && mlen > settingsMessageBodyLength { - return nil, errors.Errorf("Message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength) + return nil, errors.Errorf("message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength) } - // @todo get user from context - var currentUserID uint64 = repository.Identity(svc.ctx) + var currentUserID = auth.GetIdentityFromContext(svc.ctx).Identity() - // @todo verify if current user can access & write to this channel - _ = currentUserID + if !svc.prm.CanReadChannelByID(in.ChannelID) { + return nil, errors.New("not allowed to access channel") + } return message, svc.db.Transaction(func() (err error) { message, err = svc.message.FindMessageByID(in.ID) if err != nil { - return errors.Wrap(err, "Could not load message for editing") + return errors.Wrap(err, "could not load message for editing") } if message.Message == in.Message { @@ -231,7 +236,7 @@ func (svc *message) Update(in *types.Message) (message *types.Message, err error } if message.UserID != currentUserID { - return errors.New("Not an owner") + return errors.New("not an owner") } // Allow message content to be changed @@ -250,10 +255,8 @@ func (svc *message) Update(in *types.Message) (message *types.Message, err error } func (svc *message) Delete(ID uint64) error { - // @todo get user from context - var currentUserID uint64 = repository.Identity(svc.ctx) + var currentUserID = auth.GetIdentityFromContext(svc.ctx).Identity() - // @todo verify if current user can access & write to this channel _ = currentUserID // @todo load current message @@ -269,6 +272,10 @@ func (svc *message) Delete(ID uint64) error { return err } + if !svc.prm.CanReadChannelByID(deletedMsg.ChannelID) { + return errors.New("not allowed to access channel") + } + if deletedMsg.ReplyTo > 0 { original, err = svc.message.FindMessageByID(deletedMsg.ReplyTo) if err != nil { @@ -386,10 +393,8 @@ func (svc *message) RemoveBookmark(messageID uint64) error { // React on a message with an emoji func (svc *message) flag(messageID uint64, flag string, remove bool) error { - // @todo get user from context - var currentUserID uint64 = repository.Identity(svc.ctx) + var currentUserID = auth.GetIdentityFromContext(svc.ctx).Identity() - // @todo verify if current user can access & write to this channel _ = currentUserID if strings.TrimSpace(flag) == "" { @@ -403,8 +408,6 @@ func (svc *message) flag(messageID uint64, flag string, remove bool) error { var flagOwnerId = currentUserID var f *types.MessageFlag - // @todo [SECURITY] verify if current user can access & write to this channel - if flag == types.MessageFlagPinnedToChannel { // It does not matter how is the owner of the pin, flagOwnerId = 0 @@ -429,6 +432,10 @@ func (svc *message) flag(messageID uint64, flag string, remove bool) error { return } + if !svc.prm.CanReadChannelByID(msg.ChannelID) { + return errors.New("not allowed to access channel") + } + if remove { err = svc.mflag.DeleteByID(f.ID) f.DeletedAt = timeNowPtr() @@ -450,7 +457,7 @@ func (svc *message) flag(messageID uint64, flag string, remove bool) error { return }) - return errors.Wrap(err, "Can not flag/un-flag message") + return errors.Wrap(err, "can not flag/un-flag message") } func (svc *message) preload(mm types.MessageSet) (err error) { @@ -620,7 +627,7 @@ func (svc *message) extractMentions(m *types.Message) (mm types.MentionSet) { func (svc *message) updateMentions(messageID uint64, mm types.MentionSet) error { if existing, err := svc.mentions.FindByMessageIDs(messageID); err != nil { - return errors.Wrap(err, "Could not update mentions") + return errors.Wrap(err, "could not update mentions") } else if len(mm) > 0 { add, _, del := existing.Diff(mm) @@ -630,7 +637,7 @@ func (svc *message) updateMentions(messageID uint64, mm types.MentionSet) error }) if err != nil { - return errors.Wrap(err, "Could not create mentions") + return errors.Wrap(err, "could not create mentions") } err = del.Walk(func(m *types.Mention) error { @@ -638,7 +645,7 @@ func (svc *message) updateMentions(messageID uint64, mm types.MentionSet) error }) if err != nil { - return errors.Wrap(err, "Could not delete mentions") + return errors.Wrap(err, "could not delete mentions") } } else { return svc.mentions.DeleteByMessageID(messageID) diff --git a/messaging/service/permissions.go b/messaging/service/permissions.go index 77a2ced45..85008c19d 100644 --- a/messaging/service/permissions.go +++ b/messaging/service/permissions.go @@ -29,6 +29,7 @@ type ( CanUpdateChannel(ch *types.Channel) bool CanReadChannel(ch *types.Channel) bool + CanReadChannelByID(id uint64) bool CanJoinChannel(ch *types.Channel) bool CanLeaveChannel(ch *types.Channel) bool CanDeleteChannel(ch *types.Channel) bool @@ -96,6 +97,10 @@ func (p *permissions) CanReadChannel(ch *types.Channel) bool { return p.checkAccess(ch.Resource().String(), "read", p.canReadFallback(ch)) } +func (p *permissions) CanReadChannelByID(id uint64) bool { + return p.CanReadChannel(&types.Channel{ID: id}) +} + func (p *permissions) CanJoinChannel(ch *types.Channel) bool { return p.checkAccess(ch.Resource().String(), "join", p.canJoinFallback(ch)) }