636 lines
15 KiB
Go
636 lines
15 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"github.com/pkg/errors"
|
|
|
|
"github.com/crusttech/crust/internal/payload"
|
|
"github.com/crusttech/crust/sam/repository"
|
|
"github.com/crusttech/crust/sam/types"
|
|
systemService "github.com/crusttech/crust/system/service"
|
|
systemTypes "github.com/crusttech/crust/system/types"
|
|
)
|
|
|
|
type (
|
|
message struct {
|
|
db db
|
|
ctx context.Context
|
|
|
|
attachment repository.AttachmentRepository
|
|
channel repository.ChannelRepository
|
|
cmember repository.ChannelMemberRepository
|
|
unreads repository.UnreadRepository
|
|
message repository.MessageRepository
|
|
mflag repository.MessageFlagRepository
|
|
mentions repository.MentionRepository
|
|
|
|
usr systemService.UserService
|
|
evl EventService
|
|
}
|
|
|
|
MessageService interface {
|
|
With(ctx context.Context) MessageService
|
|
|
|
Find(filter *types.MessageFilter) (types.MessageSet, error)
|
|
FindThreads(filter *types.MessageFilter) (types.MessageSet, error)
|
|
|
|
Create(messages *types.Message) (*types.Message, error)
|
|
Update(messages *types.Message) (*types.Message, error)
|
|
|
|
React(messageID uint64, reaction string) error
|
|
RemoveReaction(messageID uint64, reaction string) error
|
|
|
|
MarkAsUnread(messageID uint64) error
|
|
|
|
Pin(messageID uint64) error
|
|
RemovePin(messageID uint64) error
|
|
|
|
Bookmark(messageID uint64) error
|
|
RemoveBookmark(messageID uint64) error
|
|
|
|
Delete(ID uint64) error
|
|
}
|
|
)
|
|
|
|
// settingsMessageBodyLength is the upper bound that Create and Update compare
// len(message body) against. The check is only active for values above zero,
// so the current value of 0 disables it.
const settingsMessageBodyLength = 0

// mentionRE matches inline references of the form <@123>, <#123> or
// <@123 some label>. Submatches: 1 = sigil (@ or #), 2 = numeric ID,
// 3 = optional whitespace plus label, 4 = label without the whitespace.
const mentionRE = `<([@#])(\d+)((?:\s)([^>]+))?>`

// mentionsFinder is mentionRE compiled once at package initialisation.
var mentionsFinder = regexp.MustCompile(mentionRE)
|
|
|
|
func Message() MessageService {
|
|
return &message{
|
|
usr: systemService.DefaultUser,
|
|
evl: DefaultEvent,
|
|
}
|
|
}
|
|
|
|
func (svc *message) With(ctx context.Context) MessageService {
|
|
db := repository.DB(ctx)
|
|
return &message{
|
|
db: db,
|
|
ctx: ctx,
|
|
|
|
usr: svc.usr.With(ctx),
|
|
evl: svc.evl.With(ctx),
|
|
|
|
attachment: repository.Attachment(ctx, db),
|
|
channel: repository.Channel(ctx, db),
|
|
cmember: repository.ChannelMember(ctx, db),
|
|
unreads: repository.ChannelView(ctx, db),
|
|
message: repository.Message(ctx, db),
|
|
mflag: repository.MessageFlag(ctx, db),
|
|
mentions: repository.Mention(ctx, db),
|
|
}
|
|
}
|
|
|
|
func (svc *message) Find(filter *types.MessageFilter) (mm types.MessageSet, err error) {
|
|
// @todo get user from context
|
|
filter.CurrentUserID = repository.Identity(svc.ctx)
|
|
|
|
// @todo verify if current user can access & read from this channel
|
|
_ = filter.ChannelID
|
|
|
|
mm, err = svc.message.FindMessages(filter)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return mm, svc.preload(mm)
|
|
}
|
|
|
|
func (svc *message) FindThreads(filter *types.MessageFilter) (mm types.MessageSet, err error) {
|
|
// @todo get user from context
|
|
filter.CurrentUserID = repository.Identity(svc.ctx)
|
|
|
|
// @todo verify if current user can access & read from this channel
|
|
_ = filter.ChannelID
|
|
|
|
mm, err = svc.message.FindThreads(filter)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return mm, svc.preload(mm)
|
|
}
|
|
|
|
func (svc *message) Create(in *types.Message) (message *types.Message, err error) {
|
|
if in == nil {
|
|
in = &types.Message{}
|
|
}
|
|
|
|
in.Message = strings.TrimSpace(in.Message)
|
|
var mlen = len(in.Message)
|
|
|
|
if mlen == 0 {
|
|
return nil, errors.Errorf("Refusing to create message without contents")
|
|
} else if settingsMessageBodyLength > 0 && mlen > settingsMessageBodyLength {
|
|
return nil, errors.Errorf("Message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength)
|
|
}
|
|
|
|
// @todo get user from context
|
|
var currentUserID uint64 = repository.Identity(svc.ctx)
|
|
|
|
in.UserID = currentUserID
|
|
|
|
return message, svc.db.Transaction(func() (err error) {
|
|
// Broadcast queue
|
|
var bq = types.MessageSet{}
|
|
|
|
if in.ReplyTo > 0 {
|
|
var original *types.Message
|
|
var replyTo = in.ReplyTo
|
|
|
|
for replyTo > 0 {
|
|
// Find original message
|
|
original, err = svc.message.FindMessageByID(in.ReplyTo)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
replyTo = original.ReplyTo
|
|
}
|
|
|
|
if !original.Type.IsRepliable() {
|
|
return errors.Errorf("Unable to reply on this message (type = %s)", original.Type)
|
|
}
|
|
|
|
// We do not want to have multi-level threads
|
|
// Take original's reply-to and use it
|
|
in.ReplyTo = original.ID
|
|
|
|
in.ChannelID = original.ChannelID
|
|
|
|
// Increment counter, on struct and in repostiry.
|
|
original.Replies++
|
|
if err = svc.message.IncReplyCount(original.ID); err != nil {
|
|
return
|
|
}
|
|
|
|
// Broadcast updated original
|
|
bq = append(bq, original)
|
|
}
|
|
|
|
if in.ChannelID == 0 {
|
|
return errors.New("ChannelID missing")
|
|
}
|
|
|
|
// @todo [SECURITY] verify if current user can access & write to this channel
|
|
|
|
if message, err = svc.message.CreateMessage(in); err != nil {
|
|
return
|
|
}
|
|
|
|
if err = svc.updateMentions(message.ID, svc.extractMentions(message)); err != nil {
|
|
return
|
|
}
|
|
|
|
if err = svc.unreads.Inc(message.ChannelID, message.ReplyTo, message.UserID); err != nil {
|
|
return
|
|
}
|
|
|
|
return svc.sendEvent(append(bq, message)...)
|
|
})
|
|
}
|
|
|
|
func (svc *message) Update(in *types.Message) (message *types.Message, err error) {
|
|
if in == nil {
|
|
in = &types.Message{}
|
|
}
|
|
|
|
in.Message = strings.TrimSpace(in.Message)
|
|
var mlen = len(in.Message)
|
|
|
|
if mlen == 0 {
|
|
return nil, errors.Errorf("Refusing to update message without contents")
|
|
} else if settingsMessageBodyLength > 0 && mlen > settingsMessageBodyLength {
|
|
return nil, errors.Errorf("Message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength)
|
|
}
|
|
|
|
// @todo get user from context
|
|
var currentUserID uint64 = repository.Identity(svc.ctx)
|
|
|
|
// @todo verify if current user can access & write to this channel
|
|
_ = currentUserID
|
|
|
|
return message, svc.db.Transaction(func() (err error) {
|
|
message, err = svc.message.FindMessageByID(in.ID)
|
|
if err != nil {
|
|
return errors.Wrap(err, "Could not load message for editing")
|
|
}
|
|
|
|
if message.Message == in.Message {
|
|
// Nothing changed
|
|
return nil
|
|
}
|
|
|
|
if message.UserID != currentUserID {
|
|
return errors.New("Not an owner")
|
|
}
|
|
|
|
// Allow message content to be changed
|
|
message.Message = in.Message
|
|
|
|
if message, err = svc.message.UpdateMessage(message); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err = svc.updateMentions(message.ID, svc.extractMentions(message)); err != nil {
|
|
return
|
|
}
|
|
|
|
return svc.sendEvent(message)
|
|
})
|
|
}
|
|
|
|
func (svc *message) Delete(ID uint64) error {
|
|
// @todo get user from context
|
|
var currentUserID uint64 = repository.Identity(svc.ctx)
|
|
|
|
// @todo verify if current user can access & write to this channel
|
|
_ = currentUserID
|
|
|
|
// @todo load current message
|
|
// @todo verify ownership
|
|
|
|
return svc.db.Transaction(func() (err error) {
|
|
// Broadcast queue
|
|
var bq = types.MessageSet{}
|
|
var deletedMsg, original *types.Message
|
|
|
|
deletedMsg, err = svc.message.FindMessageByID(ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if deletedMsg.ReplyTo > 0 {
|
|
original, err = svc.message.FindMessageByID(deletedMsg.ReplyTo)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// This is a reply to another message, decrease reply counter on the original, on struct and in the
|
|
// repository
|
|
if original.Replies > 0 {
|
|
original.Replies--
|
|
}
|
|
|
|
if err = svc.message.DecReplyCount(original.ID); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Broadcast updated original
|
|
bq = append(bq, original)
|
|
}
|
|
|
|
if err = svc.message.DeleteMessageByID(ID); err != nil {
|
|
return
|
|
}
|
|
|
|
if err = svc.unreads.Dec(deletedMsg.ChannelID, deletedMsg.ReplyTo, deletedMsg.UserID); err != nil {
|
|
return err
|
|
} else {
|
|
// Set deletedAt timestamp so that our clients can react properly...
|
|
deletedMsg.DeletedAt = timeNowPtr()
|
|
}
|
|
|
|
if err = svc.updateMentions(ID, nil); err != nil {
|
|
return
|
|
}
|
|
|
|
return svc.sendEvent(append(bq, deletedMsg)...)
|
|
})
|
|
}
|
|
|
|
// Pin message to the channel
|
|
func (svc *message) MarkAsUnread(messageID uint64) error {
|
|
var currentUserID uint64 = repository.Identity(svc.ctx)
|
|
|
|
return svc.db.Transaction(func() (err error) {
|
|
// Broadcast queue
|
|
var message *types.Message
|
|
var count uint32
|
|
|
|
message, err = svc.message.FindMessageByID(messageID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
count, err = svc.message.CountFromMessageID(message.ChannelID, message.ReplyTo, message.ID)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// Inc counter so that we take
|
|
// this message into account
|
|
count++
|
|
|
|
if message.ReplyTo > 0 {
|
|
return svc.unreads.Record(currentUserID, message.ChannelID, message.ReplyTo, messageID, count)
|
|
} else {
|
|
return svc.unreads.Record(currentUserID, message.ChannelID, 0, messageID, count)
|
|
}
|
|
})
|
|
}
|
|
|
|
// React on a message with an emoji
|
|
func (svc *message) React(messageID uint64, reaction string) error {
|
|
return svc.flag(messageID, reaction, false)
|
|
}
|
|
|
|
// Remove reaction on a message
|
|
func (svc *message) RemoveReaction(messageID uint64, reaction string) error {
|
|
return svc.flag(messageID, reaction, true)
|
|
}
|
|
|
|
// Pin message to the channel
|
|
func (svc *message) Pin(messageID uint64) error {
|
|
return svc.flag(messageID, types.MessageFlagPinnedToChannel, false)
|
|
}
|
|
|
|
// Remove pin from message
|
|
func (svc *message) RemovePin(messageID uint64) error {
|
|
return svc.flag(messageID, types.MessageFlagPinnedToChannel, true)
|
|
}
|
|
|
|
// Bookmark message (private)
|
|
func (svc *message) Bookmark(messageID uint64) error {
|
|
return svc.flag(messageID, types.MessageFlagBookmarkedMessage, false)
|
|
}
|
|
|
|
// Remove bookmark message (private)
|
|
func (svc *message) RemoveBookmark(messageID uint64) error {
|
|
return svc.flag(messageID, types.MessageFlagBookmarkedMessage, true)
|
|
}
|
|
|
|
// React on a message with an emoji
|
|
func (svc *message) flag(messageID uint64, flag string, remove bool) error {
|
|
// @todo get user from context
|
|
var currentUserID uint64 = repository.Identity(svc.ctx)
|
|
|
|
// @todo verify if current user can access & write to this channel
|
|
_ = currentUserID
|
|
|
|
if strings.TrimSpace(flag) == "" {
|
|
// Sanitize
|
|
flag = types.MessageFlagPinnedToChannel
|
|
}
|
|
|
|
// @todo validate flags beyond empty string
|
|
|
|
err := svc.db.Transaction(func() (err error) {
|
|
var flagOwnerId = currentUserID
|
|
var f *types.MessageFlag
|
|
|
|
// @todo [SECURITY] verify if current user can access & write to this channel
|
|
|
|
if flag == types.MessageFlagPinnedToChannel {
|
|
// It does not matter how is the owner of the pin,
|
|
flagOwnerId = 0
|
|
}
|
|
|
|
f, err = svc.mflag.FindByFlag(messageID, flagOwnerId, flag)
|
|
if f.ID == 0 && remove {
|
|
// Skip removing, flag does not exists
|
|
return nil
|
|
} else if f.ID > 0 && !remove {
|
|
// Skip adding, flag already exists
|
|
return nil
|
|
} else if err != nil && err != repository.ErrMessageFlagNotFound {
|
|
// Other errors, exit
|
|
return
|
|
}
|
|
|
|
// Check message
|
|
var msg *types.Message
|
|
msg, err = svc.message.FindMessageByID(messageID)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if remove {
|
|
err = svc.mflag.DeleteByID(f.ID)
|
|
f.DeletedAt = timeNowPtr()
|
|
} else {
|
|
f, err = svc.mflag.Create(&types.MessageFlag{
|
|
UserID: currentUserID,
|
|
ChannelID: msg.ChannelID,
|
|
MessageID: msg.ID,
|
|
Flag: flag,
|
|
})
|
|
}
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
svc.sendFlagEvent(f)
|
|
|
|
return
|
|
})
|
|
|
|
return errors.Wrap(err, "Can not flag/un-flag message")
|
|
}
|
|
|
|
func (svc *message) preload(mm types.MessageSet) (err error) {
|
|
if err = svc.preloadUsers(mm); err != nil {
|
|
return
|
|
}
|
|
|
|
if err = svc.preloadAttachments(mm); err != nil {
|
|
return
|
|
}
|
|
|
|
if err = svc.preloadFlags(mm); err != nil {
|
|
return
|
|
}
|
|
|
|
if err = svc.preloadMentions(mm); err != nil {
|
|
return
|
|
}
|
|
|
|
if err = svc.message.PrefillThreadParticipants(mm); err != nil {
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// Preload for all messages
|
|
func (svc *message) preloadUsers(mm types.MessageSet) (err error) {
|
|
var uu systemTypes.UserSet
|
|
|
|
for _, msg := range mm {
|
|
if msg.User != nil || msg.UserID == 0 {
|
|
continue
|
|
}
|
|
|
|
if msg.User = uu.FindByID(msg.UserID); msg.User != nil {
|
|
continue
|
|
}
|
|
|
|
if msg.User, _ = svc.usr.FindByID(msg.UserID); msg.User != nil {
|
|
// @todo fix this handler errors (ignore user-not-found, return others)
|
|
uu = append(uu, msg.User)
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// Preload for all messages
|
|
func (svc *message) preloadFlags(mm types.MessageSet) (err error) {
|
|
var ff types.MessageFlagSet
|
|
|
|
ff, err = svc.mflag.FindByMessageIDs(mm.IDs()...)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
return ff.Walk(func(flag *types.MessageFlag) error {
|
|
mm.FindByID(flag.MessageID).Flags = append(mm.FindByID(flag.MessageID).Flags, flag)
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// Preload for all messages
|
|
func (svc *message) preloadMentions(mm types.MessageSet) (err error) {
|
|
var mentions types.MentionSet
|
|
|
|
mentions, err = svc.mentions.FindByMessageIDs(mm.IDs()...)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
return mm.Walk(func(m *types.Message) error {
|
|
m.Mentions = mentions.FindByMessageID(m.ID)
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (svc *message) preloadAttachments(mm types.MessageSet) (err error) {
|
|
var (
|
|
ids []uint64
|
|
aa types.MessageAttachmentSet
|
|
)
|
|
|
|
err = mm.Walk(func(m *types.Message) error {
|
|
if m.Type == types.MessageTypeAttachment || m.Type == types.MessageTypeInlineImage {
|
|
ids = append(ids, m.ID)
|
|
}
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if aa, err = svc.attachment.FindAttachmentByMessageID(ids...); err != nil {
|
|
return
|
|
} else {
|
|
return aa.Walk(func(a *types.MessageAttachment) error {
|
|
if a.MessageID > 0 {
|
|
if m := mm.FindByID(a.MessageID); m != nil {
|
|
m.Attachment = &a.Attachment
|
|
}
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
}
|
|
|
|
// Sends message to event loop
|
|
func (svc *message) sendEvent(mm ...*types.Message) (err error) {
|
|
if err = svc.preload(mm); err != nil {
|
|
return
|
|
}
|
|
|
|
for _, msg := range mm {
|
|
if msg.User == nil {
|
|
// @todo fix this handler errors (ignore user-not-found, return others)
|
|
msg.User, _ = svc.usr.FindByID(msg.UserID)
|
|
}
|
|
|
|
if err = svc.evl.Message(msg); err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// Sends message to event loop
|
|
func (svc *message) sendFlagEvent(ff ...*types.MessageFlag) (err error) {
|
|
for _, f := range ff {
|
|
if err = svc.evl.MessageFlag(f); err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (svc *message) extractMentions(m *types.Message) (mm types.MentionSet) {
|
|
const reSubID = 2
|
|
mm = types.MentionSet{}
|
|
|
|
match := mentionsFinder.FindAllStringSubmatch(m.Message, -1)
|
|
|
|
// Prepopulated with all we know from message
|
|
tpl := types.Mention{
|
|
ChannelID: m.ChannelID,
|
|
MessageID: m.ID,
|
|
MentionedByID: m.UserID,
|
|
}
|
|
|
|
for m := 0; m < len(match); m++ {
|
|
uid := payload.ParseUInt64(match[m][reSubID])
|
|
if len(mm.FindByUserID(uid)) == 0 {
|
|
// Copy template & assign user id
|
|
mnt := tpl
|
|
mnt.UserID = uid
|
|
mm = append(mm, &mnt)
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (svc *message) updateMentions(messageID uint64, mm types.MentionSet) error {
|
|
if existing, err := svc.mentions.FindByMessageIDs(messageID); err != nil {
|
|
return errors.Wrap(err, "Could not update mentions")
|
|
} else if len(mm) > 0 {
|
|
add, _, del := existing.Diff(mm)
|
|
|
|
err = add.Walk(func(m *types.Mention) error {
|
|
m, err = svc.mentions.Create(m)
|
|
return err
|
|
})
|
|
|
|
if err != nil {
|
|
return errors.Wrap(err, "Could not create mentions")
|
|
}
|
|
|
|
err = del.Walk(func(m *types.Mention) error {
|
|
return svc.mentions.DeleteByID(m.ID)
|
|
})
|
|
|
|
if err != nil {
|
|
return errors.Wrap(err, "Could not delete mentions")
|
|
}
|
|
} else {
|
|
return svc.mentions.DeleteByMessageID(messageID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Compile-time check that *message implements MessageService.
var _ MessageService = (*message)(nil)
|