upd(messaging): message permission checks
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/crusttech/crust/internal/auth"
|
||||
"github.com/crusttech/crust/internal/payload"
|
||||
"github.com/crusttech/crust/messaging/repository"
|
||||
"github.com/crusttech/crust/messaging/types"
|
||||
@@ -29,6 +30,7 @@ type (
|
||||
|
||||
usr systemService.UserService
|
||||
evl EventService
|
||||
prm PermissionsService
|
||||
}
|
||||
|
||||
MessageService interface {
|
||||
@@ -68,6 +70,7 @@ func Message() MessageService {
|
||||
return &message{
|
||||
usr: systemService.DefaultUser,
|
||||
evl: DefaultEvent,
|
||||
prm: DefaultPermissions,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,6 +82,7 @@ func (svc *message) With(ctx context.Context) MessageService {
|
||||
|
||||
usr: svc.usr.With(ctx),
|
||||
evl: svc.evl.With(ctx),
|
||||
prm: svc.prm.With(ctx),
|
||||
|
||||
attachment: repository.Attachment(ctx, db),
|
||||
channel: repository.Channel(ctx, db),
|
||||
@@ -91,11 +95,11 @@ func (svc *message) With(ctx context.Context) MessageService {
|
||||
}
|
||||
|
||||
func (svc *message) Find(filter *types.MessageFilter) (mm types.MessageSet, err error) {
|
||||
// @todo get user from context
|
||||
filter.CurrentUserID = repository.Identity(svc.ctx)
|
||||
filter.CurrentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
|
||||
|
||||
// @todo verify if current user can access & read from this channel
|
||||
_ = filter.ChannelID
|
||||
if !svc.prm.CanReadChannelByID(filter.ChannelID) {
|
||||
return nil, errors.New("not allowed to access channel")
|
||||
}
|
||||
|
||||
mm, err = svc.message.FindMessages(filter)
|
||||
if err != nil {
|
||||
@@ -106,11 +110,11 @@ func (svc *message) Find(filter *types.MessageFilter) (mm types.MessageSet, err
|
||||
}
|
||||
|
||||
func (svc *message) FindThreads(filter *types.MessageFilter) (mm types.MessageSet, err error) {
|
||||
// @todo get user from context
|
||||
filter.CurrentUserID = repository.Identity(svc.ctx)
|
||||
filter.CurrentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
|
||||
|
||||
// @todo verify if current user can access & read from this channel
|
||||
_ = filter.ChannelID
|
||||
if !svc.prm.CanReadChannelByID(filter.ChannelID) {
|
||||
return nil, errors.New("not allowed to access channel")
|
||||
}
|
||||
|
||||
mm, err = svc.message.FindThreads(filter)
|
||||
if err != nil {
|
||||
@@ -129,13 +133,12 @@ func (svc *message) Create(in *types.Message) (message *types.Message, err error
|
||||
var mlen = len(in.Message)
|
||||
|
||||
if mlen == 0 {
|
||||
return nil, errors.Errorf("Refusing to create message without contents")
|
||||
return nil, errors.Errorf("refusing to create message without contents")
|
||||
} else if settingsMessageBodyLength > 0 && mlen > settingsMessageBodyLength {
|
||||
return nil, errors.Errorf("Message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength)
|
||||
return nil, errors.Errorf("message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength)
|
||||
}
|
||||
|
||||
// @todo get user from context
|
||||
var currentUserID uint64 = repository.Identity(svc.ctx)
|
||||
var currentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
|
||||
|
||||
in.UserID = currentUserID
|
||||
|
||||
@@ -158,7 +161,7 @@ func (svc *message) Create(in *types.Message) (message *types.Message, err error
|
||||
}
|
||||
|
||||
if !original.Type.IsRepliable() {
|
||||
return errors.Errorf("Unable to reply on this message (type = %s)", original.Type)
|
||||
return errors.Errorf("unable to reply on this message (type = %s)", original.Type)
|
||||
}
|
||||
|
||||
// We do not want to have multi-level threads
|
||||
@@ -167,7 +170,7 @@ func (svc *message) Create(in *types.Message) (message *types.Message, err error
|
||||
|
||||
in.ChannelID = original.ChannelID
|
||||
|
||||
// Increment counter, on struct and in repostiry.
|
||||
// Increment counter, on struct and in repository.
|
||||
original.Replies++
|
||||
if err = svc.message.IncReplyCount(original.ID); err != nil {
|
||||
return
|
||||
@@ -178,10 +181,12 @@ func (svc *message) Create(in *types.Message) (message *types.Message, err error
|
||||
}
|
||||
|
||||
if in.ChannelID == 0 {
|
||||
return errors.New("ChannelID missing")
|
||||
return errors.New("channelID missing")
|
||||
}
|
||||
|
||||
// @todo [SECURITY] verify if current user can access & write to this channel
|
||||
if !svc.prm.CanReadChannelByID(in.ChannelID) {
|
||||
return errors.New("not allowed to access channel")
|
||||
}
|
||||
|
||||
if message, err = svc.message.CreateMessage(in); err != nil {
|
||||
return
|
||||
@@ -208,21 +213,21 @@ func (svc *message) Update(in *types.Message) (message *types.Message, err error
|
||||
var mlen = len(in.Message)
|
||||
|
||||
if mlen == 0 {
|
||||
return nil, errors.Errorf("Refusing to update message without contents")
|
||||
return nil, errors.Errorf("refusing to update message without contents")
|
||||
} else if settingsMessageBodyLength > 0 && mlen > settingsMessageBodyLength {
|
||||
return nil, errors.Errorf("Message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength)
|
||||
return nil, errors.Errorf("message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength)
|
||||
}
|
||||
|
||||
// @todo get user from context
|
||||
var currentUserID uint64 = repository.Identity(svc.ctx)
|
||||
var currentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
|
||||
|
||||
// @todo verify if current user can access & write to this channel
|
||||
_ = currentUserID
|
||||
if !svc.prm.CanReadChannelByID(in.ChannelID) {
|
||||
return nil, errors.New("not allowed to access channel")
|
||||
}
|
||||
|
||||
return message, svc.db.Transaction(func() (err error) {
|
||||
message, err = svc.message.FindMessageByID(in.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Could not load message for editing")
|
||||
return errors.Wrap(err, "could not load message for editing")
|
||||
}
|
||||
|
||||
if message.Message == in.Message {
|
||||
@@ -231,7 +236,7 @@ func (svc *message) Update(in *types.Message) (message *types.Message, err error
|
||||
}
|
||||
|
||||
if message.UserID != currentUserID {
|
||||
return errors.New("Not an owner")
|
||||
return errors.New("not an owner")
|
||||
}
|
||||
|
||||
// Allow message content to be changed
|
||||
@@ -250,10 +255,8 @@ func (svc *message) Update(in *types.Message) (message *types.Message, err error
|
||||
}
|
||||
|
||||
func (svc *message) Delete(ID uint64) error {
|
||||
// @todo get user from context
|
||||
var currentUserID uint64 = repository.Identity(svc.ctx)
|
||||
var currentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
|
||||
|
||||
// @todo verify if current user can access & write to this channel
|
||||
_ = currentUserID
|
||||
|
||||
// @todo load current message
|
||||
@@ -269,6 +272,10 @@ func (svc *message) Delete(ID uint64) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if !svc.prm.CanReadChannelByID(deletedMsg.ChannelID) {
|
||||
return errors.New("not allowed to access channel")
|
||||
}
|
||||
|
||||
if deletedMsg.ReplyTo > 0 {
|
||||
original, err = svc.message.FindMessageByID(deletedMsg.ReplyTo)
|
||||
if err != nil {
|
||||
@@ -386,10 +393,8 @@ func (svc *message) RemoveBookmark(messageID uint64) error {
|
||||
|
||||
// React on a message with an emoji
|
||||
func (svc *message) flag(messageID uint64, flag string, remove bool) error {
|
||||
// @todo get user from context
|
||||
var currentUserID uint64 = repository.Identity(svc.ctx)
|
||||
var currentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
|
||||
|
||||
// @todo verify if current user can access & write to this channel
|
||||
_ = currentUserID
|
||||
|
||||
if strings.TrimSpace(flag) == "" {
|
||||
@@ -403,8 +408,6 @@ func (svc *message) flag(messageID uint64, flag string, remove bool) error {
|
||||
var flagOwnerId = currentUserID
|
||||
var f *types.MessageFlag
|
||||
|
||||
// @todo [SECURITY] verify if current user can access & write to this channel
|
||||
|
||||
if flag == types.MessageFlagPinnedToChannel {
|
||||
// It does not matter how is the owner of the pin,
|
||||
flagOwnerId = 0
|
||||
@@ -429,6 +432,10 @@ func (svc *message) flag(messageID uint64, flag string, remove bool) error {
|
||||
return
|
||||
}
|
||||
|
||||
if !svc.prm.CanReadChannelByID(msg.ChannelID) {
|
||||
return errors.New("not allowed to access channel")
|
||||
}
|
||||
|
||||
if remove {
|
||||
err = svc.mflag.DeleteByID(f.ID)
|
||||
f.DeletedAt = timeNowPtr()
|
||||
@@ -450,7 +457,7 @@ func (svc *message) flag(messageID uint64, flag string, remove bool) error {
|
||||
return
|
||||
})
|
||||
|
||||
return errors.Wrap(err, "Can not flag/un-flag message")
|
||||
return errors.Wrap(err, "can not flag/un-flag message")
|
||||
}
|
||||
|
||||
func (svc *message) preload(mm types.MessageSet) (err error) {
|
||||
@@ -620,7 +627,7 @@ func (svc *message) extractMentions(m *types.Message) (mm types.MentionSet) {
|
||||
|
||||
func (svc *message) updateMentions(messageID uint64, mm types.MentionSet) error {
|
||||
if existing, err := svc.mentions.FindByMessageIDs(messageID); err != nil {
|
||||
return errors.Wrap(err, "Could not update mentions")
|
||||
return errors.Wrap(err, "could not update mentions")
|
||||
} else if len(mm) > 0 {
|
||||
add, _, del := existing.Diff(mm)
|
||||
|
||||
@@ -630,7 +637,7 @@ func (svc *message) updateMentions(messageID uint64, mm types.MentionSet) error
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Could not create mentions")
|
||||
return errors.Wrap(err, "could not create mentions")
|
||||
}
|
||||
|
||||
err = del.Walk(func(m *types.Mention) error {
|
||||
@@ -638,7 +645,7 @@ func (svc *message) updateMentions(messageID uint64, mm types.MentionSet) error
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Could not delete mentions")
|
||||
return errors.Wrap(err, "could not delete mentions")
|
||||
}
|
||||
} else {
|
||||
return svc.mentions.DeleteByMessageID(messageID)
|
||||
|
||||
@@ -29,6 +29,7 @@ type (
|
||||
|
||||
CanUpdateChannel(ch *types.Channel) bool
|
||||
CanReadChannel(ch *types.Channel) bool
|
||||
CanReadChannelByID(id uint64) bool
|
||||
CanJoinChannel(ch *types.Channel) bool
|
||||
CanLeaveChannel(ch *types.Channel) bool
|
||||
CanDeleteChannel(ch *types.Channel) bool
|
||||
@@ -96,6 +97,10 @@ func (p *permissions) CanReadChannel(ch *types.Channel) bool {
|
||||
return p.checkAccess(ch.Resource().String(), "read", p.canReadFallback(ch))
|
||||
}
|
||||
|
||||
func (p *permissions) CanReadChannelByID(id uint64) bool {
|
||||
return p.CanReadChannel(&types.Channel{ID: id})
|
||||
}
|
||||
|
||||
func (p *permissions) CanJoinChannel(ch *types.Channel) bool {
|
||||
return p.checkAccess(ch.Resource().String(), "join", p.canJoinFallback(ch))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user