diff --git a/internal/payload/outgoing.go b/internal/payload/outgoing.go index 2218d8fdc..070233e67 100644 --- a/internal/payload/outgoing.go +++ b/internal/payload/outgoing.go @@ -40,7 +40,7 @@ func Message(ctx context.Context, msg *messagingTypes.Message) *outgoing.Message ReplyTo: msg.ReplyTo, Replies: msg.Replies, RepliesFrom: Uint64stoa(msg.RepliesFrom), - Unread: Unread(msg.Unread), + Unread: MessageUnread(msg.Unread), Attachment: Attachment(msg.Attachment, currentUserID), Mentions: messageMentionSet(msg.Mentions), @@ -145,7 +145,7 @@ func Channel(ch *messagingTypes.Channel) *outgoing.Channel { Type: string(ch.Type), MembershipFlag: string(flag), Members: Uint64stoa(ch.Members), - Unread: Unread(ch.Unread), + Unread: ChannelUnread(ch.Unread), CanJoin: ch.CanJoin, CanPart: ch.CanPart, @@ -196,10 +196,37 @@ func Unread(v *messagingTypes.Unread) *outgoing.Unread { return nil } + return &outgoing.Unread{ + ChannelID: v.ChannelID, + ThreadID: v.ReplyTo, + LastMessageID: v.LastMessageID, + Count: v.Count, + ThreadCount: v.ThreadCount, + ThreadTotal: v.ThreadTotal, + } +} + +func ChannelUnread(v *messagingTypes.Unread) *outgoing.Unread { + if v == nil || (v.Count == 0 && v.ThreadCount == 0) { + return nil + } + + return &outgoing.Unread{ + LastMessageID: v.LastMessageID, + Count: v.Count, + ThreadCount: v.ThreadCount, + ThreadTotal: v.ThreadTotal, + } +} + +func MessageUnread(v *messagingTypes.Unread) *outgoing.Unread { + if v == nil || v.Count == 0 { + return nil + } + return &outgoing.Unread{ LastMessageID: v.LastMessageID, Count: v.Count, - InThreadCount: v.InThreadCount, } } diff --git a/internal/payload/outgoing/payload.go b/internal/payload/outgoing/payload.go index c556c6f47..0a93c8e23 100644 --- a/internal/payload/outgoing/payload.go +++ b/internal/payload/outgoing/payload.go @@ -19,6 +19,8 @@ type ( *Channel `json:"channel,omitempty"` *ChannelSet `json:"channels,omitempty"` + *Unread `json:"unread,omitempty"` + *ChannelMember `json:"channelMember,omitempty"` *ChannelMemberSet `json:"channelMembers,omitempty"` diff --git a/internal/payload/outgoing/unread.go b/internal/payload/outgoing/unread.go index a9d904041..c53a6ccbd 100644 --- a/internal/payload/outgoing/unread.go +++ b/internal/payload/outgoing/unread.go @@ -1,10 +1,20 @@ package outgoing +import "encoding/json" + type ( Unread struct { - // Channel to part (nil) for ALL channels + ChannelID uint64 `json:"channelID,string,omitempty"` + ThreadID uint64 `json:"threadID,string,omitempty"` + LastMessageID uint64 `json:"lastMessageID,string,omitempty"` Count uint32 `json:"count"` - InThreadCount uint32 `json:"tcount,omitempty"` + + ThreadCount uint32 `json:"threadCount"` + ThreadTotal uint32 `json:"threadTotal,omitempty"` } ) + +func (p *Unread) EncodeMessage() ([]byte, error) { + return json.Marshal(Payload{Unread: p}) +} diff --git a/messaging/internal/repository/repository.go b/messaging/internal/repository/repository.go index 9a224817c..9b6fc6495 100644 --- a/messaging/internal/repository/repository.go +++ b/messaging/internal/repository/repository.go @@ -4,6 +4,7 @@ import ( "context" "github.com/titpetric/factory" + "gopkg.in/Masterminds/squirrel.v1" "github.com/cortezaproject/corteza-server/internal/auth" ) @@ -45,3 +46,21 @@ func (r *repository) db() *factory.DB { } return DB(r.ctx) } + +// Fetches single row from table +func (r repository) fetchSet(set interface{}, q squirrel.SelectBuilder) (err error) { + var ( + sql string + args []interface{} + ) + + if sql, args, err = q.ToSql(); err != nil { + return + } + + if err = r.db().Select(set, sql, args...); err != nil { + return + } + + return +} diff --git a/messaging/internal/repository/unread.go b/messaging/internal/repository/unread.go index a7534ef1b..b28252501 100644 --- a/messaging/internal/repository/unread.go +++ b/messaging/internal/repository/unread.go @@ -3,8 +3,8 @@ package repository import ( "context" - "github.com/jmoiron/sqlx" "github.com/titpetric/factory" + "gopkg.in/Masterminds/squirrel.v1" "github.com/cortezaproject/corteza-server/messaging/types" ) @@ -14,7 +14,9 @@ type ( UnreadRepository interface { With(ctx context.Context, db *factory.DB) UnreadRepository - Find(filter *types.UnreadFilter) (types.UnreadSet, error) + Count(userID, channelID uint64, threadIDs ...uint64) (types.UnreadSet, error) + CountThreads(userID, channelID uint64) (types.UnreadSet, error) + Preset(channelID, threadID uint64, userIDs ...uint64) (err error) Record(userID, channelID, threadID, lastReadMessageID uint64, count uint32) error Inc(channelID, replyTo, userID uint64) error @@ -30,17 +32,6 @@ type ( ) const ( - // Fetching channel members of all channels a specific user has access to - sqlUnreadSelect = `SELECT rel_channel, rel_reply_to, rel_user, count, rel_last_message - FROM messaging_unread - WHERE count > 0 && rel_last_message > 0 ` - - // Fetching channel members of all channels a specific user has access to - sqlThreadUnreadSelect = `SELECT rel_channel, sum(count) as count - FROM messaging_unread - WHERE rel_user = ? AND rel_reply_to > 0 - GROUP BY rel_channel` - sqlUnreadIncCount = `UPDATE messaging_unread SET count = count + 1 WHERE rel_channel = ? AND rel_reply_to = ? AND rel_user <> ?` @@ -49,6 +40,8 @@ const ( SET count = count - 1 WHERE rel_channel = ? AND rel_reply_to = ? AND count > 0` + sqlResetCount = `REPLACE INTO messaging_unread (rel_channel, rel_reply_to, rel_user, count) VALUES (?, ?, ?, 0)` + sqlUnreadPresetChannel = `INSERT IGNORE INTO messaging_unread (rel_channel, rel_reply_to, rel_user) VALUES (?, ?, ?)` sqlUnreadPresetThreads = `INSERT IGNORE INTO messaging_unread (rel_channel, rel_reply_to, rel_user) SELECT rel_channel, id, ? @@ -62,6 +55,10 @@ func Unread(ctx context.Context, db *factory.DB) UnreadRepository { return (&unread{}).With(ctx, db) } +func (r unread) table() string { + return "messaging_unread" +} + // With context... func (r *unread) With(ctx context.Context, db *factory.DB) UnreadRepository { return &unread{ @@ -69,55 +66,84 @@ func (r *unread) With(ctx context.Context, db *factory.DB) UnreadRepository { } } -// Find unread info -func (r *unread) Find(filter *types.UnreadFilter) (uu types.UnreadSet, err error) { - params := make([]interface{}, 0) - sql := sqlUnreadSelect +// Count returns counts unread channel info +func (r *unread) Count(userID, channelID uint64, threadIDs ...uint64) (types.UnreadSet, error) { + var ( + uu = types.UnreadSet{} + q = squirrel. + Select(). + From(r.table()). + Columns( + "rel_channel", + "rel_last_message", + "rel_user", + "rel_reply_to", + "count") + ) - if filter != nil { - if filter.UserID > 0 { - // scope: only channel we have access to - sql += ` AND rel_user = ?` - params = append(params, filter.UserID) - } - - if filter.ChannelID > 0 { - // scope: only channel we have access to - sql += ` AND rel_channel = ?` - params = append(params, filter.ChannelID) - } - - if len(filter.ThreadIDs) > 0 { - sql += ` AND rel_reply_to IN (?)` - params = append(params, filter.ThreadIDs) - } else { - sql += ` AND rel_reply_to = 0` - } + if userID > 0 { + q = q.Where("rel_user = ?", userID) } - if sql, params, err = sqlx.In(sql, params...); err != nil { + if channelID > 0 { + q = q.Where("rel_channel = ?", channelID) + } + + if len(threadIDs) == 0 { + q = q.Where("rel_reply_to = 0") + } else { + q = q.Where(squirrel.Eq{"rel_reply_to": threadIDs}) + } + + return uu, r.fetchSet(&uu, q) +} + +// CountReplies counts unread thread info +func (r unread) CountThreads(userID, channelID uint64) (types.UnreadSet, error) { + type ( + u struct { + Rel_channel, Rel_user uint64 + Total, Count uint32 + } + ) + var ( + err error + + uu = types.UnreadSet{} + + temp = []*u{} + + q = squirrel. + Select(). + From(r.table()). + Columns( + "rel_channel", + "rel_user", + "sum(count) AS count", + "sum(CASE WHEN count > 0 THEN 1 ELSE 0 END) AS total"). + Where("rel_reply_to > 0 AND count > 0"). + GroupBy("rel_channel", "rel_user") + ) + + if userID > 0 { + q = q.Where("rel_user = ?", userID) + } + + if channelID > 0 { + q = q.Where("rel_channel = ?", channelID) + } + + err = r.fetchSet(&temp, q) + if err != nil { return nil, err - } else if err = r.db().Select(&uu, sql, params...); err != nil { - return nil, err - } else if len(filter.ThreadIDs) == 0 && filter.UserID > 0 { - // Check for unread thread messages + } - // We'll abuse Unread/UnreadSet - tt := types.UnreadSet{} - - err = r.db().Select(&tt, sqlThreadUnreadSelect, filter.UserID) - - _ = tt.Walk(func(t *types.Unread) error { - c := uu.FindByChannelId(t.ChannelID) - if c != nil { - c.InThreadCount = t.Count - } else { - // No un-reads in channel but we have them in threads (of that channel) - // swap values and append - t.InThreadCount, t.Count = t.Count, 0 - uu = append(uu, t) - } - return nil + for _, t := range temp { + uu = append(uu, &types.Unread{ + ChannelID: t.Rel_channel, + UserID: t.Rel_user, + ThreadCount: t.Count, + ThreadTotal: t.Total, }) } @@ -171,15 +197,27 @@ func (r *unread) Record(userID, channelID, threadID, lastReadMessageID uint64, c } // Inc increments unread message count on a channel/thread for all but one user -func (r *unread) Inc(channelID, threadID, userID uint64) error { - _, err := r.db().Exec(sqlUnreadIncCount, channelID, threadID, userID) - return err +func (r *unread) Inc(channelID, threadID, userID uint64) (err error) { + _, err = r.db().Exec(sqlUnreadIncCount, channelID, threadID, userID) + if err != nil { + return err + } + + return nil } // Dec decrements unread message count on a channel/thread for all but one user -func (r *unread) Dec(channelID, threadID, userID uint64) error { - _, err := r.db().Exec(sqlUnreadDecCount, channelID, threadID) - return err +func (r *unread) Dec(channelID, threadID, userID uint64) (err error) { + _, err = r.db().Exec(sqlUnreadDecCount, channelID, threadID) + if err != nil { + return err + } + _, err = r.db().Exec(sqlResetCount, channelID, threadID, userID) + if err != nil { + return err + } + + return nil } func (r *unread) CountOwned(userID uint64) (c int, err error) { diff --git a/messaging/internal/service/channel.go b/messaging/internal/service/channel.go index 38c8c73b6..3c1548c60 100644 --- a/messaging/internal/service/channel.go +++ b/messaging/internal/service/channel.go @@ -170,17 +170,41 @@ func (svc *channel) preloadMembers(cc types.ChannelSet) (err error) { return } +// preload channel unread info for a single user func (svc *channel) preloadUnreads(cc types.ChannelSet) error { var userID = auth.GetIdentityFromContext(svc.ctx).Identity() - if vv, err := svc.unread.Find(&types.UnreadFilter{UserID: userID}); err != nil { + if uu, err := svc.unread.Count(userID, 0); err != nil { return err } else { - return cc.Walk(func(ch *types.Channel) error { - ch.Unread = vv.FindByChannelId(ch.ID) + _ = cc.Walk(func(ch *types.Channel) error { + ch.Unread = uu.FindByChannelId(ch.ID) return nil }) } + + if uu, err := svc.unread.CountThreads(userID, 0); err != nil { + return err + } else { + _ = cc.Walk(func(ch *types.Channel) error { + var u = uu.FindByChannelId(ch.ID) + + if u == nil { + return nil + } + + if ch.Unread == nil { + ch.Unread = &types.Unread{} + } + + ch.Unread.ThreadCount = u.ThreadCount + ch.Unread.ThreadTotal = u.ThreadTotal + + return nil + }) + } + + return nil } // FindMembers loads all members (and full users) for a specific channel diff --git a/messaging/internal/service/event.go b/messaging/internal/service/event.go index 7f2049606..01e3d83c4 100644 --- a/messaging/internal/service/event.go +++ b/messaging/internal/service/event.go @@ -24,6 +24,7 @@ type ( Activity(a *types.Activity) error Message(m *types.Message) error MessageFlag(m *types.MessageFlag) error + UnreadCounters(uu types.UnreadSet) error Channel(m *types.Channel) error Join(userID, channelID uint64) error Part(userID, channelID uint64) error @@ -85,6 +86,12 @@ func (svc event) MessageFlag(f *types.MessageFlag) error { return nil } +func (svc event) UnreadCounters(uu types.UnreadSet) error { + return uu.Walk(func(u *types.Unread) error { + return svc.push(payload.Unread(u), types.EventQueueItemSubTypeUser, u.UserID) + }) +} + // Channel notifies subscribers about channel change // // If this is a public channel we notify everyone @@ -106,7 +113,7 @@ func (svc event) Join(userID, channelID uint64) (err error) { join := payload.ChannelJoin(channelID, userID) // Subscribe user to the channel - if err = svc.push(join, types.EventQueueItemSubTypeUser, 0); err != nil { + if err = svc.push(join, types.EventQueueItemSubTypeUser, userID); err != nil { return } @@ -131,7 +138,7 @@ func (svc event) Part(userID, channelID uint64) (err error) { } // Subscribe user to the channel - if err = svc.push(part, types.EventQueueItemSubTypeUser, 0); err != nil { + if err = svc.push(part, types.EventQueueItemSubTypeUser, userID); err != nil { return } diff --git a/messaging/internal/service/message.go b/messaging/internal/service/message.go index f54bfc60f..117bd69b0 100644 --- a/messaging/internal/service/message.go +++ b/messaging/internal/service/message.go @@ -189,7 +189,7 @@ func (svc message) filterMessagesByAccessibleChannels(mm types.MessageSet) types return mm } -func (svc message) Create(in *types.Message) (message *types.Message, err error) { +func (svc message) Create(in *types.Message) (m *types.Message, err error) { if in == nil { in = &types.Message{} } @@ -211,7 +211,7 @@ func (svc message) Create(in *types.Message) (message *types.Message, err error) in.UserID = auth.GetIdentityFromContext(svc.ctx).Identity() } - return message, svc.db.Transaction(func() (err error) { + return m, svc.db.Transaction(func() (err error) { // Broadcast queue var bq = types.MessageSet{} var ch *types.Channel @@ -279,19 +279,21 @@ func (svc message) Create(in *types.Message) (message *types.Message, err error) return ErrNoPermissions.withStack() } - if message, err = svc.message.Create(in); err != nil { + if m, err = svc.message.Create(in); err != nil { return } - if err = svc.updateMentions(message.ID, svc.extractMentions(message)); err != nil { + mentions := svc.extractMentions(m) + if err = svc.updateMentions(m.ID, mentions); err != nil { return } - if err = svc.unread.Inc(message.ChannelID, message.ReplyTo, message.UserID); err != nil { - return - } + svc.sendNotifications(m, mentions) - return svc.sendEvent(append(bq, message)...) + // Count unreads in the background and send updates to all users + svc.countUnreads(ch, m, 0) + + return svc.sendEvent(append(bq, m)...) }) } @@ -413,17 +415,16 @@ func (svc message) Delete(messageID uint64) error { return } - if err = svc.unread.Dec(deletedMsg.ChannelID, deletedMsg.ReplyTo, deletedMsg.UserID); err != nil { - return err - } else { - // Set deletedAt timestamp so that our clients can react properly... - deletedMsg.DeletedAt = timeNowPtr() - } + // Set deletedAt timestamp so that our clients can react properly... + deletedMsg.DeletedAt = timeNowPtr() if err = svc.updateMentions(messageID, nil); err != nil { return } + // Count unreads in the background and send updates to all users + svc.countUnreads(ch, deletedMsg, 0) + return svc.sendEvent(append(bq, deletedMsg)...) }) } @@ -435,7 +436,7 @@ func (svc message) MarkAsRead(channelID, threadID, lastReadMessageID uint64) (ui var ( currentUserID uint64 = repository.Identity(svc.ctx) count uint32 - tcount uint32 + threadCount uint32 err error ) @@ -463,9 +464,13 @@ func (svc message) MarkAsRead(channelID, threadID, lastReadMessageID uint64) (ui // This is request for channel, // count all thread unreads var uu types.UnreadSet - uu, err = svc.unread.Find(&types.UnreadFilter{UserID: currentUserID, ChannelID: channelID}) + uu, err = svc.unread.CountThreads(currentUserID, channelID) + if err != nil { + return err + } + if u := uu.FindByChannelId(channelID); u != nil { - tcount = u.InThreadCount + threadCount = u.ThreadCount } } @@ -497,10 +502,17 @@ func (svc message) MarkAsRead(channelID, threadID, lastReadMessageID uint64) (ui } err = svc.unread.Record(currentUserID, channelID, threadID, lastReadMessageID, count) - return errors.Wrap(err, "unable to record unread messages") + if err != nil { + return errors.Wrap(err, "unable to record unread messages") + } + + // Re-count unreads and send updates to this user + svc.countUnreads(ch, nil, currentUserID) + + return nil }) - return lastReadMessageID, count, tcount, errors.Wrap(err, "unable to mark as read") + return lastReadMessageID, count, threadCount, errors.Wrap(err, "unable to mark as read") } // React on a message with an emoji @@ -599,9 +611,7 @@ func (svc message) flag(messageID uint64, flag string, remove bool) error { return } - // @todo: log possible error - svc.sendFlagEvent(f) - + _ = svc.sendFlagEvent(f) return }) @@ -706,7 +716,7 @@ func (svc message) preloadUnreads(mm types.MessageSet) error { return nil } - if vv, err := svc.unread.Find(&types.UnreadFilter{UserID: userID, ThreadIDs: mm.IDs()}); err != nil { + if vv, err := svc.unread.Count(userID, 0, mm.IDs()...); err != nil { return err } else { return mm.Walk(func(m *types.Message) error { @@ -731,6 +741,90 @@ func (svc message) sendEvent(mm ...*types.Message) (err error) { return } +// Generates and sends notifications from the new message +// +// +func (svc message) sendNotifications(message *types.Message, mentions types.MentionSet) { + // @todo implementation +} + +// countUnreads orchestrates unread-related operations (inc/dec, (re)counting & sending events) +// +// 1. increases/decreases unread counters for channel or thread +// 2. collects all counters for channel or thread +// 3. sends unread events to subscribers +func (svc message) countUnreads(ch *types.Channel, m *types.Message, userID uint64) { + var ( + err error + uuBase, uuThreads, uuChannels types.UnreadSet + // mm types.ChannelMemberSet + threadIDs []uint64 + ) + + if m != nil { + if m.DeletedAt != nil { + // When deleting message, all existing counters are decreased! + if err = svc.unread.Dec(m.ChannelID, m.ReplyTo, m.UserID); err != nil { + svc.logger.With(zap.Error(err)).Info("could not decrement unread counter") + return + } + } else if m.UpdatedAt == nil { + // Reset user's counter and set current message ID as last read. + err = svc.unread.Record( + m.UserID, + m.ChannelID, + m.ReplyTo, + m.ID, + 0, + ) + + // When new message is created, update all existing counters + if err = svc.unread.Inc(m.ChannelID, m.ReplyTo, m.UserID); err != nil { + svc.logger.With(zap.Error(err)).Info("could not increment unread counter") + return + } + } + + if m.ReplyTo > 0 { + threadIDs = []uint64{m.ReplyTo} + } + } + + uuBase, err = svc.unread.Count(userID, ch.ID, threadIDs...) + if err != nil { + svc.logger.With(zap.Error(err)).Info("could not count unread messages") + return + } + + if len(threadIDs) > 0 { + // If base count was done for a thread, + // Do another count for channel + uuChannels, err = svc.unread.Count(userID, ch.ID) + if err != nil { + svc.logger.With(zap.Error(err)).Info("could not count unread messages") + return + } + + uuBase = uuBase.Merge(uuChannels) + + // Now recount all threads for this channel + uuThreads, err = svc.unread.CountThreads(userID, ch.ID) + if err != nil { + svc.logger.With(zap.Error(err)).Info("could not count unread messages") + return + } + + uuBase = uuBase.Merge(uuThreads) + } + + // This is a reply, make sure we fetch the new stats about unread replies and push them to users + err = svc.event.UnreadCounters(uuBase) + if err != nil { + svc.logger.With(zap.Error(err)).Info("could not send unread count event") + return + } +} + // Sends message to event loop func (svc message) sendFlagEvent(ff ...*types.MessageFlag) (err error) { for _, f := range ff { diff --git a/messaging/rest/message.go b/messaging/rest/message.go index 6436dbf64..f04c95746 100644 --- a/messaging/rest/message.go +++ b/messaging/rest/message.go @@ -68,7 +68,7 @@ func (ctrl *Message) MarkAsRead(ctx context.Context, r *request.MessageMarkAsRea return outgoing.Unread{ LastMessageID: messageID, Count: count, - InThreadCount: tcount, + ThreadCount: tcount, }, err } diff --git a/messaging/types/unread.go b/messaging/types/unread.go index 162bbc167..4bfdae1f7 100644 --- a/messaging/types/unread.go +++ b/messaging/types/unread.go @@ -7,13 +7,38 @@ type ( UserID uint64 `db:"rel_user"` LastMessageID uint64 `db:"rel_last_message"` - Count uint32 `db:"count"` - InThreadCount uint32 `db:"-"` - } - - UnreadFilter struct { - UserID uint64 - ChannelID uint64 - ThreadIDs []uint64 + Count uint32 `db:"count"` + ThreadCount uint32 `db:"-"` + ThreadTotal uint32 `db:"-"` } ) + +func (uu UnreadSet) Merge(in UnreadSet) UnreadSet { + var ( + out = append(UnreadSet{}, uu...) + olen = len(out) + ) + +inSet: + for _, i := range in { + for o := 0; o < olen; o++ { + if out[o].UserID == i.UserID && out[o].ChannelID == i.ChannelID && out[o].ReplyTo == i.ReplyTo { + if i.Count > 0 { + out[o].Count = i.Count + } + if i.ThreadCount > 0 { + out[o].ThreadCount = i.ThreadCount + } + if i.ThreadTotal > 0 { + out[o].ThreadTotal = i.ThreadTotal + } + + continue inSet + } + } + + out = append(out, i) + } + + return out +} diff --git a/messaging/types/unread_test.go b/messaging/types/unread_test.go new file mode 100644 index 000000000..ef00d29de --- /dev/null +++ b/messaging/types/unread_test.go @@ -0,0 +1,44 @@ +package types + +import ( + "reflect" + "testing" +) + +func TestUnreadSet_Merge(t *testing.T) { + type args struct { + in UnreadSet + } + tests := []struct { + name string + uu UnreadSet + args args + want UnreadSet + }{ + { + name: "simple", + uu: UnreadSet{&Unread{ChannelID: 1, UserID: 1, Count: 2}}, + args: args{in: UnreadSet{&Unread{ChannelID: 1, UserID: 1, ThreadCount: 3, ThreadTotal: 4}}}, + want: UnreadSet{&Unread{ChannelID: 1, UserID: 1, Count: 2, ThreadCount: 3, ThreadTotal: 4}}, + }, + { + name: "empty base", + uu: UnreadSet{}, + args: args{in: UnreadSet{&Unread{ChannelID: 1, UserID: 1, ThreadCount: 3, ThreadTotal: 4}}}, + want: UnreadSet{&Unread{ChannelID: 1, UserID: 1, Count: 0, ThreadCount: 3, ThreadTotal: 4}}, + }, + { + name: "emmpt input", + uu: UnreadSet{&Unread{ChannelID: 1, UserID: 1, Count: 2}}, + args: args{in: nil}, + want: UnreadSet{&Unread{ChannelID: 1, UserID: 1, Count: 2}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.uu.Merge(tt.args.in); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Merge() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/messaging/websocket/event_queue.go b/messaging/websocket/event_queue.go index 46bd9d957..94716a4b3 100644 --- a/messaging/websocket/event_queue.go +++ b/messaging/websocket/event_queue.go @@ -3,6 +3,7 @@ package websocket import ( "context" "encoding/json" + "errors" sentry "github.com/getsentry/sentry-go" "github.com/titpetric/factory" @@ -57,6 +58,10 @@ func (eq *eventQueue) store(ctx context.Context, qp repository.EventsRepository) } func (eq *eventQueue) feedSessions(ctx context.Context, qp repository.EventsRepository, store eventQueueWalker) error { + var ( + userID uint64 + ) + for { item, err := qp.Pull(ctx) if err != nil { @@ -64,6 +69,11 @@ func (eq *eventQueue) feedSessions(ctx context.Context, qp repository.EventsRepo } if item.SubType == types.EventQueueItemSubTypeUser { + userID = payload.ParseUInt64(item.Subscriber) + if userID == 0 { + return errors.New("subscriber could not be parsed as uint64") + } + p := &outgoing.Payload{} if err := json.Unmarshal(item.Payload, p); err != nil { @@ -76,7 +86,7 @@ func (eq *eventQueue) feedSessions(ctx context.Context, qp repository.EventsRepo // This store.Walk handler does not send to subscribed sessions but // subscribes all sessions that belong to the same user store.Walk(func(s *Session) { - if payload.Uint64toa(s.user.Identity()) == p.ChannelJoin.UserID { + if s.user.Identity() == userID { s.subs.Add(p.ChannelJoin.ID) } }) @@ -86,25 +96,28 @@ func (eq *eventQueue) feedSessions(ctx context.Context, qp repository.EventsRepo // This store.Walk handler does not send to subscribed sessions but // subscribes all sessions that belong to the same user store.Walk(func(s *Session) { - if payload.Uint64toa(s.user.Identity()) == p.ChannelPart.UserID { + if s.user.Identity() == userID { s.subs.Delete(p.ChannelPart.ID) } }) } else { - // No other payload types we can handle at the moment. - return nil + store.Walk(func(s *Session) { + if s.user.Identity() == userID { + _ = s.sendBytes(item.Payload) + } + }) } } else if item.Subscriber == "" { // Distribute payload to all connected sessions store.Walk(func(s *Session) { - s.sendBytes(item.Payload) + _ = s.sendBytes(item.Payload) }) } else { // Distribute payload to specific subscribers store.Walk(func(s *Session) { if s.subs.Get(item.Subscriber) != nil { - s.sendBytes(item.Payload) + _ = s.sendBytes(item.Payload) } }) }