diff --git a/messaging/service/channel.go b/messaging/service/channel.go index 3e1b74d65..6b7d01be9 100644 --- a/messaging/service/channel.go +++ b/messaging/service/channel.go @@ -665,8 +665,10 @@ func (svc *channel) AddMember(channelID uint64, memberIDs ...uint64) (out types. } } - if !(svc.prm.CanManageChannelMembers(ch) || memberID == userID && ch.Type == types.ChannelTypePublic) { - return errors.New("not allowed to add members") + if memberID == userID && !svc.prm.CanJoinChannel(ch) { + return errors.New("not allowed to join") + } else if !svc.prm.CanManageChannelMembers(ch) { + return errors.New("not allowed to add channel members") } if !exists { @@ -734,8 +736,10 @@ func (svc *channel) DeleteMember(channelID uint64, memberIDs ...uint64) (err err continue } - if !(svc.prm.CanManageChannelMembers(ch) || memberID == userID && ch.Type == types.ChannelTypePublic) { - return errors.New("not allowed to add members") + if memberID == userID && !svc.prm.CanJoinChannel(ch) { + return errors.New("not allowed to leave") + } else if !svc.prm.CanManageChannelMembers(ch) { + return errors.New("not allowed to remove channel members") } if userID == memberID { diff --git a/messaging/service/message.go b/messaging/service/message.go index 6da3c611b..e8c8131c3 100644 --- a/messaging/service/message.go +++ b/messaging/service/message.go @@ -145,6 +145,7 @@ func (svc *message) Create(in *types.Message) (message *types.Message, err error return message, svc.db.Transaction(func() (err error) { // Broadcast queue var bq = types.MessageSet{} + var ch *types.Channel if in.ReplyTo > 0 { var original *types.Message @@ -182,10 +183,12 @@ func (svc *message) Create(in *types.Message) (message *types.Message, err error if in.ChannelID == 0 { return errors.New("channelID missing") - } - - if !svc.prm.CanReadChannelByID(in.ChannelID) { - return errors.New("not allowed to access channel") + } else if ch, err = svc.channel.FindChannelByID(in.ChannelID); err != nil { + return + } else if in.ReplyTo > 0 && !svc.prm.CanReplyMessage(ch) { + return errors.New("not allowed to reply in this channel") + } else if !svc.prm.CanSendMessage(ch) { + return errors.New("not allowed to send messages in this channel") } if message, err = svc.message.CreateMessage(in); err != nil { @@ -219,6 +222,7 @@ func (svc *message) Update(in *types.Message) (message *types.Message, err error } var currentUserID = auth.GetIdentityFromContext(svc.ctx).Identity() + var ch *types.Channel if !svc.prm.CanReadChannelByID(in.ChannelID) { return nil, errors.New("not allowed to access channel") @@ -235,8 +239,12 @@ func (svc *message) Update(in *types.Message) (message *types.Message, err error return nil } - if message.UserID != currentUserID { - return errors.New("not an owner") + if ch, err = svc.channel.FindChannelByID(message.ChannelID); err != nil { + return + } else if message.UserID == currentUserID && !svc.prm.CanUpdateOwnMessages(ch) { + return errors.New("not allowed to edit your messages in this channel") + } else if !svc.prm.CanUpdateMessages(ch) { + return errors.New("not allowed to edit messages in this channel") } // Allow message content to be changed @@ -259,13 +267,11 @@ func (svc *message) Delete(ID uint64) error { _ = currentUserID - // @todo load current message - // @todo verify ownership - return svc.db.Transaction(func() (err error) { // Broadcast queue var bq = types.MessageSet{} var deletedMsg, original *types.Message + var ch *types.Channel deletedMsg, err = svc.message.FindMessageByID(ID) if err != nil { @@ -282,6 +288,14 @@ func (svc *message) Delete(ID uint64) error { return err } + if ch, err = svc.channel.FindChannelByID(original.ChannelID); err != nil { + return + } else if original.UserID == currentUserID && !svc.prm.CanUpdateOwnMessages(ch) { + return errors.New("not allowed to delete your messages in this channel") + } else if !svc.prm.CanUpdateMessages(ch) { + return errors.New("not allowed to delete messages in this channel") + } + // This is a reply to another message, decrease reply counter on the original, on struct and in the // repository if original.Replies > 0 { @@ -407,6 +421,8 @@ func (svc *message) flag(messageID uint64, flag string, remove bool) error { err := svc.db.Transaction(func() (err error) { var flagOwnerId = currentUserID var f *types.MessageFlag + var msg *types.Message + var ch *types.Channel if flag == types.MessageFlagPinnedToChannel { // It does not matter how is the owner of the pin, @@ -425,15 +441,14 @@ func (svc *message) flag(messageID uint64, flag string, remove bool) error { return } - // Check message - var msg *types.Message - msg, err = svc.message.FindMessageByID(messageID) - if err != nil { + if msg, err = svc.message.FindMessageByID(messageID); err != nil { return - } - - if !svc.prm.CanReadChannelByID(msg.ChannelID) { + } else if ch, err = svc.channel.FindChannelByID(msg.ChannelID); err != nil { + return + } else if !svc.prm.CanReadChannel(ch) { return errors.New("not allowed to access channel") + } else if f.IsReaction() && !svc.prm.CanReactMessage(ch) { + return errors.New("not allowed to react on channels") } if remove {