3
0
Files
corteza/sam/service/message.go
2018-11-21 07:02:57 +01:00

636 lines
15 KiB
Go

package service
import (
"context"
"regexp"
"strings"
"github.com/pkg/errors"
"github.com/crusttech/crust/internal/payload"
"github.com/crusttech/crust/sam/repository"
"github.com/crusttech/crust/sam/types"
systemService "github.com/crusttech/crust/system/service"
systemTypes "github.com/crusttech/crust/system/types"
)
// message is the default MessageService implementation. Values built by
// Message() only carry the user and event services; the repositories, the
// database handle and the context are filled in by With(ctx).
type message struct {
	db  db
	ctx context.Context

	// Repositories, all bound to db and ctx by With.
	attachment repository.AttachmentRepository
	channel    repository.ChannelRepository
	cmember    repository.ChannelMemberRepository
	unreads    repository.UnreadRepository
	message    repository.MessageRepository
	mflag      repository.MessageFlagRepository
	mentions   repository.MentionRepository

	// Services from outside the repository layer.
	usr systemService.UserService
	evl EventService
}

// MessageService creates, edits, deletes, flags and lists chat messages.
type MessageService interface {
	// With returns a copy of the service bound to ctx.
	With(ctx context.Context) MessageService

	// Find returns messages matching filter with related data preloaded.
	Find(filter *types.MessageFilter) (types.MessageSet, error)
	// FindThreads returns thread messages matching filter with related data preloaded.
	FindThreads(filter *types.MessageFilter) (types.MessageSet, error)

	// Create stores and broadcasts a new message.
	Create(messages *types.Message) (*types.Message, error)
	// Update changes the body of an existing message.
	Update(messages *types.Message) (*types.Message, error)

	// React / RemoveReaction add or remove a reaction flag.
	React(messageID uint64, reaction string) error
	RemoveReaction(messageID uint64, reaction string) error

	// MarkAsUnread moves the current user's unread marker back to messageID.
	MarkAsUnread(messageID uint64) error

	// Pin / RemovePin add or remove the channel pin flag.
	Pin(messageID uint64) error
	RemovePin(messageID uint64) error

	// Bookmark / RemoveBookmark add or remove the user's bookmark flag.
	Bookmark(messageID uint64) error
	RemoveBookmark(messageID uint64) error

	// Delete removes a message.
	Delete(ID uint64) error
}
const (
	// settingsMessageBodyLength is the maximum accepted body length;
	// a value of 0 (the current setting) disables the check.
	settingsMessageBodyLength = 0

	// mentionRE matches mention tokens such as "<@123>" or "<#45 label>":
	// group 1 is the sigil (@ or #), group 2 the numeric ID, group 3 the
	// optional whitespace-prefixed label and group 4 the label itself.
	mentionRE = `<([@#])(\d+)((?:\s)([^>]+))?>`
)

// mentionsFinder is mentionRE compiled once at package initialisation.
var mentionsFinder = regexp.MustCompile(mentionRE)
// Message returns a MessageService wired to the default user and event
// services. Repositories are only populated by With(ctx), so call With
// before using the returned service.
func Message() MessageService {
	svc := &message{}
	svc.usr = systemService.DefaultUser
	svc.evl = DefaultEvent

	return svc
}
// With returns a new service bound to ctx. The user and event services are
// re-bound to ctx, and every repository is created on the database handle
// obtained from ctx.
func (svc *message) With(ctx context.Context) MessageService {
	conn := repository.DB(ctx)

	return &message{
		db:  conn,
		ctx: ctx,

		usr: svc.usr.With(ctx),
		evl: svc.evl.With(ctx),

		attachment: repository.Attachment(ctx, conn),
		channel:    repository.Channel(ctx, conn),
		cmember:    repository.ChannelMember(ctx, conn),
		// NOTE(review): unreads is built by repository.ChannelView; presumably
		// that constructor yields an UnreadRepository — confirm.
		unreads:  repository.ChannelView(ctx, conn),
		message:  repository.Message(ctx, conn),
		mflag:    repository.MessageFlag(ctx, conn),
		mentions: repository.Mention(ctx, conn),
	}
}
// Find lists messages matching filter on behalf of the user identified by
// the service context and preloads their users, attachments, flags,
// mentions and thread participants.
//
// filter.CurrentUserID is overwritten with the context identity.
func (svc *message) Find(filter *types.MessageFilter) (mm types.MessageSet, err error) {
	// @todo get user from context
	// @todo verify if current user can access & read from this channel
	filter.CurrentUserID = repository.Identity(svc.ctx)

	if mm, err = svc.message.FindMessages(filter); err != nil {
		return nil, err
	}

	err = svc.preload(mm)
	return mm, err
}
// FindThreads lists thread messages matching filter on behalf of the user
// identified by the service context and preloads their related data, the
// same way Find does.
//
// filter.CurrentUserID is overwritten with the context identity.
func (svc *message) FindThreads(filter *types.MessageFilter) (mm types.MessageSet, err error) {
	// @todo get user from context
	// @todo verify if current user can access & read from this channel
	filter.CurrentUserID = repository.Identity(svc.ctx)

	if mm, err = svc.message.FindThreads(filter); err != nil {
		return nil, err
	}

	err = svc.preload(mm)
	return mm, err
}
// Create stores a new message authored by the user identified by the service
// context and broadcasts it through the event service.
//
// The body is whitespace-trimmed and must not be empty. When
// settingsMessageBodyLength is positive, the body must also not exceed it.
// NOTE(review): the limit is compared against len(), i.e. bytes, while the
// error message speaks of characters.
//
// When in.ReplyTo is set, the reply chain is followed up to the thread root
// (threads are single-level). The new message is attached to that root and
// inherits its channel. The root's reply counter is incremented and the
// root is broadcast too.
//
// On success the stored message is returned. On any error the result is nil.
func (svc *message) Create(in *types.Message) (message *types.Message, err error) {
	if in == nil {
		in = &types.Message{}
	}

	in.Message = strings.TrimSpace(in.Message)

	var mlen = len(in.Message)
	if mlen == 0 {
		return nil, errors.Errorf("Refusing to create message without contents")
	} else if settingsMessageBodyLength > 0 && mlen > settingsMessageBodyLength {
		return nil, errors.Errorf("Message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength)
	}

	// @todo get user from context
	var currentUserID uint64 = repository.Identity(svc.ctx)
	in.UserID = currentUserID

	// Run the transaction as its own statement. The closure assigns
	// `message`, and Go leaves the evaluation order of
	// `return message, svc.db.Transaction(...)` unspecified.
	err = svc.db.Transaction(func() (err error) {
		// Broadcast queue: already-existing messages whose updated state is
		// sent to clients together with the new message.
		var bq = types.MessageSet{}

		if in.ReplyTo > 0 {
			var original *types.Message

			// Walk up to the thread root. Each step must load the current
			// link (replyTo); re-loading in.ReplyTo would never terminate
			// for a reply to a reply.
			for replyTo := in.ReplyTo; replyTo > 0; replyTo = original.ReplyTo {
				if original, err = svc.message.FindMessageByID(replyTo); err != nil {
					return
				}
			}

			if !original.Type.IsRepliable() {
				return errors.Errorf("Unable to reply on this message (type = %s)", original.Type)
			}

			// We do not want to have multi-level threads:
			// attach the reply to the root and inherit its channel.
			in.ReplyTo = original.ID
			in.ChannelID = original.ChannelID

			// Increment counter, on struct and in repository.
			original.Replies++
			if err = svc.message.IncReplyCount(original.ID); err != nil {
				return
			}

			// Broadcast updated original
			bq = append(bq, original)
		}

		if in.ChannelID == 0 {
			return errors.New("ChannelID missing")
		}

		// @todo [SECURITY] verify if current user can access & write to this channel
		if message, err = svc.message.CreateMessage(in); err != nil {
			return
		}

		if err = svc.updateMentions(message.ID, svc.extractMentions(message)); err != nil {
			return
		}

		if err = svc.unreads.Inc(message.ChannelID, message.ReplyTo, message.UserID); err != nil {
			return
		}

		return svc.sendEvent(append(bq, message)...)
	})

	if err != nil {
		// Do not hand out a message whose transaction was rolled back.
		return nil, err
	}

	return message, nil
}
// Update replaces the body of the existing message in.ID with in.Message and
// broadcasts the change. The body is validated the same way as in Create.
//
// Only the message author (the user identified by the service context) may
// edit it. The ownership check runs before the no-change short-circuit, so
// a non-owner never gets a successful response. If the new body equals the
// stored one, nothing is written or broadcast and the stored message is
// returned.
//
// On any error the returned message is nil.
func (svc *message) Update(in *types.Message) (message *types.Message, err error) {
	if in == nil {
		in = &types.Message{}
	}

	in.Message = strings.TrimSpace(in.Message)

	var mlen = len(in.Message)
	if mlen == 0 {
		return nil, errors.Errorf("Refusing to update message without contents")
	} else if settingsMessageBodyLength > 0 && mlen > settingsMessageBodyLength {
		return nil, errors.Errorf("Message length (%d characters) too long (max: %d)", mlen, settingsMessageBodyLength)
	}

	// @todo get user from context
	var currentUserID uint64 = repository.Identity(svc.ctx)

	// @todo verify if current user can access & write to this channel

	// Run the transaction as its own statement. The closure assigns
	// `message`, and Go leaves the evaluation order of
	// `return message, svc.db.Transaction(...)` unspecified.
	err = svc.db.Transaction(func() (err error) {
		if message, err = svc.message.FindMessageByID(in.ID); err != nil {
			return errors.Wrap(err, "Could not load message for editing")
		}

		if message.UserID != currentUserID {
			return errors.New("Not an owner")
		}

		if message.Message == in.Message {
			// Nothing changed
			return nil
		}

		// Only the message content is allowed to change.
		message.Message = in.Message

		if message, err = svc.message.UpdateMessage(message); err != nil {
			return err
		}

		if err = svc.updateMentions(message.ID, svc.extractMentions(message)); err != nil {
			return
		}

		return svc.sendEvent(message)
	})

	if err != nil {
		return nil, err
	}

	return message, nil
}
// Delete removes the message identified by ID inside a transaction.
//
// If the message is a thread reply, the root's reply counter is decremented
// (in memory and in the repository) and the root is broadcast. The deleted
// message is then broadcast with DeletedAt set. Its unread counters are
// decremented and its recorded mentions are removed via updateMentions.
func (svc *message) Delete(ID uint64) error {
	// @todo get user from context
	var currentUserID uint64 = repository.Identity(svc.ctx)

	// @todo verify if current user can access & write to this channel
	// @todo load current message
	// @todo verify ownership
	_ = currentUserID

	return svc.db.Transaction(func() (err error) {
		var target *types.Message
		if target, err = svc.message.FindMessageByID(ID); err != nil {
			return err
		}

		// Messages sent to clients once the deletion has gone through:
		// the thread root (if any) first, the deleted message last.
		broadcast := types.MessageSet{}

		if target.ReplyTo > 0 {
			var root *types.Message
			if root, err = svc.message.FindMessageByID(target.ReplyTo); err != nil {
				return err
			}

			// Mirror the repository decrement on the struct, never going below zero.
			if root.Replies > 0 {
				root.Replies--
			}
			if err = svc.message.DecReplyCount(root.ID); err != nil {
				return err
			}

			broadcast = append(broadcast, root)
		}

		if err = svc.message.DeleteMessageByID(ID); err != nil {
			return err
		}

		if err = svc.unreads.Dec(target.ChannelID, target.ReplyTo, target.UserID); err != nil {
			return err
		}

		// Set deletedAt timestamp so that our clients can react properly...
		target.DeletedAt = timeNowPtr()

		if err = svc.updateMentions(ID, nil); err != nil {
			return err
		}

		return svc.sendEvent(append(broadcast, target)...)
	})
}
// MarkAsUnread records messageID as the current user's unread position in
// the message's channel (and thread, when the message is a reply).
//
// The stored unread count is the value returned by
// message.CountFromMessageID plus one, so the message itself is counted.
func (svc *message) MarkAsUnread(messageID uint64) error {
	var currentUserID uint64 = repository.Identity(svc.ctx)

	return svc.db.Transaction(func() (err error) {
		var msg *types.Message
		if msg, err = svc.message.FindMessageByID(messageID); err != nil {
			return err
		}

		var count uint32
		if count, err = svc.message.CountFromMessageID(msg.ChannelID, msg.ReplyTo, msg.ID); err != nil {
			return err
		}

		// Inc counter so that we take this message into account.
		count++

		// For thread replies msg.ReplyTo is the thread root; for top-level
		// messages it is 0. Either way it is exactly the thread key to record.
		return svc.unreads.Record(currentUserID, msg.ChannelID, msg.ReplyTo, messageID, count)
	})
}
// React adds the reaction flag (e.g. an emoji) for the current user on the
// message identified by messageID.
// NOTE(review): flag() turns an empty reaction into a channel pin.
func (svc *message) React(messageID uint64, reaction string) error {
	const removing = false
	return svc.flag(messageID, reaction, removing)
}

// RemoveReaction removes the current user's reaction flag from the message
// identified by messageID.
func (svc *message) RemoveReaction(messageID uint64, reaction string) error {
	const removing = true
	return svc.flag(messageID, reaction, removing)
}
// Pin pins the message identified by messageID to its channel.
func (svc *message) Pin(messageID uint64) error {
	const removing = false
	return svc.flag(messageID, types.MessageFlagPinnedToChannel, removing)
}

// RemovePin removes the channel pin from the message identified by
// messageID, regardless of which user created the pin.
func (svc *message) RemovePin(messageID uint64) error {
	const removing = true
	return svc.flag(messageID, types.MessageFlagPinnedToChannel, removing)
}
// Bookmark adds a private (per-user) bookmark on the message identified by
// messageID.
func (svc *message) Bookmark(messageID uint64) error {
	const removing = false
	return svc.flag(messageID, types.MessageFlagBookmarkedMessage, removing)
}

// RemoveBookmark removes the current user's private bookmark from the
// message identified by messageID.
func (svc *message) RemoveBookmark(messageID uint64) error {
	const removing = true
	return svc.flag(messageID, types.MessageFlagBookmarkedMessage, removing)
}
// flag adds (remove == false) or removes (remove == true) a flag on the
// message identified by messageID and broadcasts the resulting flag.
//
// An empty or blank flag defaults to types.MessageFlagPinnedToChannel.
// Pins are looked up regardless of who created them; all other flags
// (reactions, bookmarks) are looked up for the current user only. Adding a
// flag that already exists, or removing one that does not, is a no-op.
//
// Every error is wrapped with "Can not flag/un-flag message". A failed
// broadcast rolls the change back, matching Create/Update/Delete.
func (svc *message) flag(messageID uint64, flag string, remove bool) error {
	// @todo get user from context
	var currentUserID uint64 = repository.Identity(svc.ctx)

	if strings.TrimSpace(flag) == "" {
		// Sanitize
		flag = types.MessageFlagPinnedToChannel
	}

	// @todo validate flags beyond empty string

	err := svc.db.Transaction(func() (err error) {
		// @todo [SECURITY] verify if current user can access & write to this channel

		var flagOwnerID = currentUserID
		if flag == types.MessageFlagPinnedToChannel {
			// It does not matter who owns the pin.
			flagOwnerID = 0
		}

		var f *types.MessageFlag
		f, err = svc.mflag.FindByFlag(messageID, flagOwnerID, flag)
		if err != nil && err != repository.ErrMessageFlagNotFound {
			// Other errors, exit
			return
		}

		// Inspect f only after the error check; a failed lookup may not
		// return a usable flag.
		exists := err == nil && f != nil && f.ID > 0

		if remove && !exists {
			// Skip removing, flag does not exist
			return nil
		}
		if !remove && exists {
			// Skip adding, flag already exists
			return nil
		}

		// Check message
		var msg *types.Message
		if msg, err = svc.message.FindMessageByID(messageID); err != nil {
			return
		}

		if remove {
			if err = svc.mflag.DeleteByID(f.ID); err != nil {
				return
			}
			f.DeletedAt = timeNowPtr()
		} else {
			f, err = svc.mflag.Create(&types.MessageFlag{
				UserID:    currentUserID,
				ChannelID: msg.ChannelID,
				MessageID: msg.ID,
				Flag:      flag,
			})
			if err != nil {
				return
			}
		}

		return svc.sendFlagEvent(f)
	})

	// errors.Wrap returns nil for a nil err.
	return errors.Wrap(err, "Can not flag/un-flag message")
}
// preload fills in related data on every message of mm, in order: users,
// attachments, flags, mentions, then thread participants. It stops at the
// first step that fails and returns its error.
func (svc *message) preload(mm types.MessageSet) (err error) {
	steps := []func(types.MessageSet) error{
		svc.preloadUsers,
		svc.preloadAttachments,
		svc.preloadFlags,
		svc.preloadMentions,
		func(set types.MessageSet) error {
			return svc.message.PrefillThreadParticipants(set)
		},
	}

	for _, step := range steps {
		if err = step(mm); err != nil {
			return err
		}
	}

	return nil
}
// preloadUsers sets User on every message of mm that has a UserID but no
// User yet. Users found are cached for the duration of the call, so each
// resolved author is fetched from the user service only once. Lookup
// errors are ignored and the message is left without a User; the function
// always returns nil.
func (svc *message) preloadUsers(mm types.MessageSet) (err error) {
	var cache systemTypes.UserSet

	for _, msg := range mm {
		if msg.UserID == 0 || msg.User != nil {
			// No author to resolve, or it is already set.
			continue
		}

		if cached := cache.FindByID(msg.UserID); cached != nil {
			msg.User = cached
			continue
		}

		// @todo fix this handler errors (ignore user-not-found, return others)
		if msg.User, _ = svc.usr.FindByID(msg.UserID); msg.User != nil {
			cache = append(cache, msg.User)
		}
	}

	return nil
}
// preloadFlags fetches the flags of all messages in mm with a single
// repository query and appends each flag to its message's Flags.
//
// Flags whose message is not part of mm are skipped rather than
// dereferenced as nil (the same guard preloadAttachments uses). Each
// message is looked up once per flag instead of twice.
func (svc *message) preloadFlags(mm types.MessageSet) (err error) {
	if len(mm) == 0 {
		// Nothing to attach flags to; skip the query.
		return nil
	}

	var ff types.MessageFlagSet
	if ff, err = svc.mflag.FindByMessageIDs(mm.IDs()...); err != nil {
		return
	}

	return ff.Walk(func(flag *types.MessageFlag) error {
		if m := mm.FindByID(flag.MessageID); m != nil {
			m.Flags = append(m.Flags, flag)
		}
		return nil
	})
}
// preloadMentions fetches the mentions of all messages in mm with a single
// repository query and assigns each message the mentions that reference
// its ID.
func (svc *message) preloadMentions(mm types.MessageSet) (err error) {
	found, err := svc.mentions.FindByMessageIDs(mm.IDs()...)
	if err != nil {
		return err
	}

	for _, m := range mm {
		m.Mentions = found.FindByMessageID(m.ID)
	}

	return nil
}
// preloadAttachments loads attachments for the attachment and inline-image
// messages in mm and links each one to its message through Attachment.
// Attachments without a MessageID, or whose message is not in mm, are
// ignored.
func (svc *message) preloadAttachments(mm types.MessageSet) (err error) {
	var ids []uint64
	for _, m := range mm {
		if m.Type == types.MessageTypeAttachment || m.Type == types.MessageTypeInlineImage {
			ids = append(ids, m.ID)
		}
	}

	var aa types.MessageAttachmentSet
	if aa, err = svc.attachment.FindAttachmentByMessageID(ids...); err != nil {
		return err
	}

	return aa.Walk(func(a *types.MessageAttachment) error {
		if a.MessageID == 0 {
			return nil
		}

		if target := mm.FindByID(a.MessageID); target != nil {
			target.Attachment = &a.Attachment
		}

		return nil
	})
}
// sendEvent preloads related data for the given messages and passes them,
// one at a time and in order, to the event service. It returns the first
// preload or event error encountered.
func (svc *message) sendEvent(mm ...*types.Message) (err error) {
	if err = svc.preload(mm); err != nil {
		return err
	}

	for i := range mm {
		msg := mm[i]

		if msg.User == nil {
			// preloadUsers ignores lookup failures, so retry the author here.
			// @todo fix this handler errors (ignore user-not-found, return others)
			msg.User, _ = svc.usr.FindByID(msg.UserID)
		}

		if err = svc.evl.Message(msg); err != nil {
			return err
		}
	}

	return nil
}
// sendFlagEvent passes each message flag, in order, to the event service
// and stops at the first error, which it returns.
func (svc *message) sendFlagEvent(ff ...*types.MessageFlag) (err error) {
	for i := 0; i < len(ff); i++ {
		if err = svc.evl.MessageFlag(ff[i]); err != nil {
			return err
		}
	}

	return nil
}
// extractMentions scans the body of m for mention tokens (see mentionRE).
// It returns one Mention per distinct mentioned ID, in order of first
// appearance. Each mention carries the channel, ID and author of m. The
// result is never nil.
//
// NOTE(review): mentionRE also matches "#" (channel) tokens, and their IDs
// are stored as UserID as well — confirm this is intended.
func (svc *message) extractMentions(m *types.Message) (mm types.MentionSet) {
	// Capture group of mentionRE that holds the numeric ID.
	const idGroup = 2

	mm = types.MentionSet{}

	// Every mention shares channel, message and author; only UserID varies.
	base := types.Mention{
		ChannelID:     m.ChannelID,
		MessageID:     m.ID,
		MentionedByID: m.UserID,
	}

	for _, groups := range mentionsFinder.FindAllStringSubmatch(m.Message, -1) {
		userID := payload.ParseUInt64(groups[idGroup])
		if len(mm.FindByUserID(userID)) > 0 {
			// Already mentioned earlier in the same message.
			continue
		}

		mention := base
		mention.UserID = userID
		mm = append(mm, &mention)
	}

	return mm
}
// updateMentions makes the stored mentions of messageID equal to mm.
//
// With an empty mm every mention of the message is deleted. Otherwise the
// stored set is diffed against mm: missing mentions are created and
// obsolete ones deleted, each step failing with a wrapped error.
func (svc *message) updateMentions(messageID uint64, mm types.MentionSet) error {
	existing, err := svc.mentions.FindByMessageIDs(messageID)
	if err != nil {
		return errors.Wrap(err, "Could not update mentions")
	}

	if len(mm) == 0 {
		return svc.mentions.DeleteByMessageID(messageID)
	}

	toCreate, _, toDelete := existing.Diff(mm)

	createErr := toCreate.Walk(func(mention *types.Mention) error {
		_, cerr := svc.mentions.Create(mention)
		return cerr
	})
	if createErr != nil {
		return errors.Wrap(createErr, "Could not create mentions")
	}

	deleteErr := toDelete.Walk(func(mention *types.Mention) error {
		return svc.mentions.DeleteByID(mention.ID)
	})
	if deleteErr != nil {
		return errors.Wrap(deleteErr, "Could not delete mentions")
	}

	return nil
}
// Compile-time check that *message implements MessageService.
var _ MessageService = (*message)(nil)