Unread refactoring, moving logic to backend
This commit is contained in:
@@ -40,7 +40,7 @@ func Message(ctx context.Context, msg *messagingTypes.Message) *outgoing.Message
|
||||
ReplyTo: msg.ReplyTo,
|
||||
Replies: msg.Replies,
|
||||
RepliesFrom: Uint64stoa(msg.RepliesFrom),
|
||||
Unread: Unread(msg.Unread),
|
||||
Unread: MessageUnread(msg.Unread),
|
||||
|
||||
Attachment: Attachment(msg.Attachment, currentUserID),
|
||||
Mentions: messageMentionSet(msg.Mentions),
|
||||
@@ -145,7 +145,7 @@ func Channel(ch *messagingTypes.Channel) *outgoing.Channel {
|
||||
Type: string(ch.Type),
|
||||
MembershipFlag: string(flag),
|
||||
Members: Uint64stoa(ch.Members),
|
||||
Unread: Unread(ch.Unread),
|
||||
Unread: ChannelUnread(ch.Unread),
|
||||
|
||||
CanJoin: ch.CanJoin,
|
||||
CanPart: ch.CanPart,
|
||||
@@ -196,10 +196,37 @@ func Unread(v *messagingTypes.Unread) *outgoing.Unread {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &outgoing.Unread{
|
||||
ChannelID: v.ChannelID,
|
||||
ThreadID: v.ReplyTo,
|
||||
LastMessageID: v.LastMessageID,
|
||||
Count: v.Count,
|
||||
ThreadCount: v.ThreadCount,
|
||||
ThreadTotal: v.ThreadTotal,
|
||||
}
|
||||
}
|
||||
|
||||
func ChannelUnread(v *messagingTypes.Unread) *outgoing.Unread {
|
||||
if v == nil || (v.Count == 0 && v.ThreadCount == 0) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &outgoing.Unread{
|
||||
LastMessageID: v.LastMessageID,
|
||||
Count: v.Count,
|
||||
ThreadCount: v.ThreadCount,
|
||||
ThreadTotal: v.ThreadTotal,
|
||||
}
|
||||
}
|
||||
|
||||
func MessageUnread(v *messagingTypes.Unread) *outgoing.Unread {
|
||||
if v == nil || v.Count == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &outgoing.Unread{
|
||||
LastMessageID: v.LastMessageID,
|
||||
Count: v.Count,
|
||||
InThreadCount: v.InThreadCount,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,8 @@ type (
|
||||
*Channel `json:"channel,omitempty"`
|
||||
*ChannelSet `json:"channels,omitempty"`
|
||||
|
||||
*Unread `json:"unread,omitempty"`
|
||||
|
||||
*ChannelMember `json:"channelMember,omitempty"`
|
||||
*ChannelMemberSet `json:"channelMembers,omitempty"`
|
||||
|
||||
|
||||
@@ -1,10 +1,20 @@
|
||||
package outgoing
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type (
|
||||
Unread struct {
|
||||
// Channel to part (nil) for ALL channels
|
||||
ChannelID uint64 `json:"channelID,string,omitempty"`
|
||||
ThreadID uint64 `json:"threadID,string,omitempty"`
|
||||
|
||||
LastMessageID uint64 `json:"lastMessageID,string,omitempty"`
|
||||
Count uint32 `json:"count"`
|
||||
InThreadCount uint32 `json:"tcount,omitempty"`
|
||||
|
||||
ThreadCount uint32 `json:"threadCount"`
|
||||
ThreadTotal uint32 `json:"threadTotal,omitempty"`
|
||||
}
|
||||
)
|
||||
|
||||
func (p *Unread) EncodeMessage() ([]byte, error) {
|
||||
return json.Marshal(Payload{Unread: p})
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/titpetric/factory"
|
||||
"gopkg.in/Masterminds/squirrel.v1"
|
||||
|
||||
"github.com/cortezaproject/corteza-server/internal/auth"
|
||||
)
|
||||
@@ -45,3 +46,21 @@ func (r *repository) db() *factory.DB {
|
||||
}
|
||||
return DB(r.ctx)
|
||||
}
|
||||
|
||||
// Fetches single row from table
|
||||
func (r repository) fetchSet(set interface{}, q squirrel.SelectBuilder) (err error) {
|
||||
var (
|
||||
sql string
|
||||
args []interface{}
|
||||
)
|
||||
|
||||
if sql, args, err = q.ToSql(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = r.db().Select(set, sql, args...); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/titpetric/factory"
|
||||
"gopkg.in/Masterminds/squirrel.v1"
|
||||
|
||||
"github.com/cortezaproject/corteza-server/messaging/types"
|
||||
)
|
||||
@@ -14,7 +14,9 @@ type (
|
||||
UnreadRepository interface {
|
||||
With(ctx context.Context, db *factory.DB) UnreadRepository
|
||||
|
||||
Find(filter *types.UnreadFilter) (types.UnreadSet, error)
|
||||
Count(userID, channelID uint64, threadIDs ...uint64) (types.UnreadSet, error)
|
||||
CountThreads(userID, channelID uint64) (types.UnreadSet, error)
|
||||
|
||||
Preset(channelID, threadID uint64, userIDs ...uint64) (err error)
|
||||
Record(userID, channelID, threadID, lastReadMessageID uint64, count uint32) error
|
||||
Inc(channelID, replyTo, userID uint64) error
|
||||
@@ -30,17 +32,6 @@ type (
|
||||
)
|
||||
|
||||
const (
|
||||
// Fetching channel members of all channels a specific user has access to
|
||||
sqlUnreadSelect = `SELECT rel_channel, rel_reply_to, rel_user, count, rel_last_message
|
||||
FROM messaging_unread
|
||||
WHERE count > 0 && rel_last_message > 0 `
|
||||
|
||||
// Fetching channel members of all channels a specific user has access to
|
||||
sqlThreadUnreadSelect = `SELECT rel_channel, sum(count) as count
|
||||
FROM messaging_unread
|
||||
WHERE rel_user = ? AND rel_reply_to > 0
|
||||
GROUP BY rel_channel`
|
||||
|
||||
sqlUnreadIncCount = `UPDATE messaging_unread
|
||||
SET count = count + 1
|
||||
WHERE rel_channel = ? AND rel_reply_to = ? AND rel_user <> ?`
|
||||
@@ -49,6 +40,8 @@ const (
|
||||
SET count = count - 1
|
||||
WHERE rel_channel = ? AND rel_reply_to = ? AND count > 0`
|
||||
|
||||
sqlResetCount = `REPLACE INTO messaging_unread (rel_channel, rel_reply_to, rel_user, count) VALUES (?, ?, ?, 0)`
|
||||
|
||||
sqlUnreadPresetChannel = `INSERT IGNORE INTO messaging_unread (rel_channel, rel_reply_to, rel_user) VALUES (?, ?, ?)`
|
||||
sqlUnreadPresetThreads = `INSERT IGNORE INTO messaging_unread (rel_channel, rel_reply_to, rel_user)
|
||||
SELECT rel_channel, id, ?
|
||||
@@ -62,6 +55,10 @@ func Unread(ctx context.Context, db *factory.DB) UnreadRepository {
|
||||
return (&unread{}).With(ctx, db)
|
||||
}
|
||||
|
||||
func (r unread) table() string {
|
||||
return "messaging_unread"
|
||||
}
|
||||
|
||||
// With context...
|
||||
func (r *unread) With(ctx context.Context, db *factory.DB) UnreadRepository {
|
||||
return &unread{
|
||||
@@ -69,55 +66,84 @@ func (r *unread) With(ctx context.Context, db *factory.DB) UnreadRepository {
|
||||
}
|
||||
}
|
||||
|
||||
// Find unread info
|
||||
func (r *unread) Find(filter *types.UnreadFilter) (uu types.UnreadSet, err error) {
|
||||
params := make([]interface{}, 0)
|
||||
sql := sqlUnreadSelect
|
||||
// Count returns counts unread channel info
|
||||
func (r *unread) Count(userID, channelID uint64, threadIDs ...uint64) (types.UnreadSet, error) {
|
||||
var (
|
||||
uu = types.UnreadSet{}
|
||||
q = squirrel.
|
||||
Select().
|
||||
From(r.table()).
|
||||
Columns(
|
||||
"rel_channel",
|
||||
"rel_last_message",
|
||||
"rel_user",
|
||||
"rel_reply_to",
|
||||
"count")
|
||||
)
|
||||
|
||||
if filter != nil {
|
||||
if filter.UserID > 0 {
|
||||
// scope: only channel we have access to
|
||||
sql += ` AND rel_user = ?`
|
||||
params = append(params, filter.UserID)
|
||||
}
|
||||
|
||||
if filter.ChannelID > 0 {
|
||||
// scope: only channel we have access to
|
||||
sql += ` AND rel_channel = ?`
|
||||
params = append(params, filter.ChannelID)
|
||||
}
|
||||
|
||||
if len(filter.ThreadIDs) > 0 {
|
||||
sql += ` AND rel_reply_to IN (?)`
|
||||
params = append(params, filter.ThreadIDs)
|
||||
} else {
|
||||
sql += ` AND rel_reply_to = 0`
|
||||
}
|
||||
if userID > 0 {
|
||||
q = q.Where("rel_user = ?", userID)
|
||||
}
|
||||
|
||||
if sql, params, err = sqlx.In(sql, params...); err != nil {
|
||||
if channelID > 0 {
|
||||
q = q.Where("rel_channel = ?", channelID)
|
||||
}
|
||||
|
||||
if len(threadIDs) == 0 {
|
||||
q = q.Where("rel_reply_to = 0")
|
||||
} else {
|
||||
q = q.Where(squirrel.Eq{"rel_reply_to": threadIDs})
|
||||
}
|
||||
|
||||
return uu, r.fetchSet(&uu, q)
|
||||
}
|
||||
|
||||
// CountReplies counts unread thread info
|
||||
func (r unread) CountThreads(userID, channelID uint64) (types.UnreadSet, error) {
|
||||
type (
|
||||
u struct {
|
||||
Rel_channel, Rel_user uint64
|
||||
Total, Count uint32
|
||||
}
|
||||
)
|
||||
var (
|
||||
err error
|
||||
|
||||
uu = types.UnreadSet{}
|
||||
|
||||
temp = []*u{}
|
||||
|
||||
q = squirrel.
|
||||
Select().
|
||||
From(r.table()).
|
||||
Columns(
|
||||
"rel_channel",
|
||||
"rel_user",
|
||||
"sum(count) AS count",
|
||||
"sum(CASE WHEN count > 0 THEN 1 ELSE 0 END) AS total").
|
||||
Where("rel_reply_to > 0 AND count > 0").
|
||||
GroupBy("rel_channel", "rel_user")
|
||||
)
|
||||
|
||||
if userID > 0 {
|
||||
q = q.Where("rel_user = ?", userID)
|
||||
}
|
||||
|
||||
if channelID > 0 {
|
||||
q = q.Where("rel_channel = ?", channelID)
|
||||
}
|
||||
|
||||
err = r.fetchSet(&temp, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if err = r.db().Select(&uu, sql, params...); err != nil {
|
||||
return nil, err
|
||||
} else if len(filter.ThreadIDs) == 0 && filter.UserID > 0 {
|
||||
// Check for unread thread messages
|
||||
}
|
||||
|
||||
// We'll abuse Unread/UnreadSet
|
||||
tt := types.UnreadSet{}
|
||||
|
||||
err = r.db().Select(&tt, sqlThreadUnreadSelect, filter.UserID)
|
||||
|
||||
_ = tt.Walk(func(t *types.Unread) error {
|
||||
c := uu.FindByChannelId(t.ChannelID)
|
||||
if c != nil {
|
||||
c.InThreadCount = t.Count
|
||||
} else {
|
||||
// No un-reads in channel but we have them in threads (of that channel)
|
||||
// swap values and append
|
||||
t.InThreadCount, t.Count = t.Count, 0
|
||||
uu = append(uu, t)
|
||||
}
|
||||
return nil
|
||||
for _, t := range temp {
|
||||
uu = append(uu, &types.Unread{
|
||||
ChannelID: t.Rel_channel,
|
||||
UserID: t.Rel_user,
|
||||
ThreadCount: t.Count,
|
||||
ThreadTotal: t.Total,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -171,15 +197,27 @@ func (r *unread) Record(userID, channelID, threadID, lastReadMessageID uint64, c
|
||||
}
|
||||
|
||||
// Inc increments unread message count on a channel/thread for all but one user
|
||||
func (r *unread) Inc(channelID, threadID, userID uint64) error {
|
||||
_, err := r.db().Exec(sqlUnreadIncCount, channelID, threadID, userID)
|
||||
return err
|
||||
func (r *unread) Inc(channelID, threadID, userID uint64) (err error) {
|
||||
_, err = r.db().Exec(sqlUnreadIncCount, channelID, threadID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Dec decrements unread message count on a channel/thread for all but one user
|
||||
func (r *unread) Dec(channelID, threadID, userID uint64) error {
|
||||
_, err := r.db().Exec(sqlUnreadDecCount, channelID, threadID)
|
||||
return err
|
||||
func (r *unread) Dec(channelID, threadID, userID uint64) (err error) {
|
||||
_, err = r.db().Exec(sqlUnreadDecCount, channelID, threadID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.db().Exec(sqlResetCount, channelID, threadID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *unread) CountOwned(userID uint64) (c int, err error) {
|
||||
|
||||
@@ -170,17 +170,41 @@ func (svc *channel) preloadMembers(cc types.ChannelSet) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// preload channel unread info for a single user
|
||||
func (svc *channel) preloadUnreads(cc types.ChannelSet) error {
|
||||
var userID = auth.GetIdentityFromContext(svc.ctx).Identity()
|
||||
|
||||
if vv, err := svc.unread.Find(&types.UnreadFilter{UserID: userID}); err != nil {
|
||||
if uu, err := svc.unread.Count(userID, 0); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return cc.Walk(func(ch *types.Channel) error {
|
||||
ch.Unread = vv.FindByChannelId(ch.ID)
|
||||
_ = cc.Walk(func(ch *types.Channel) error {
|
||||
ch.Unread = uu.FindByChannelId(ch.ID)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if uu, err := svc.unread.CountThreads(userID, 0); err != nil {
|
||||
return err
|
||||
} else {
|
||||
_ = cc.Walk(func(ch *types.Channel) error {
|
||||
var u = uu.FindByChannelId(ch.ID)
|
||||
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ch.Unread == nil {
|
||||
ch.Unread = &types.Unread{}
|
||||
}
|
||||
|
||||
ch.Unread.ThreadCount = u.ThreadCount
|
||||
ch.Unread.ThreadTotal = u.ThreadTotal
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindMembers loads all members (and full users) for a specific channel
|
||||
|
||||
@@ -24,6 +24,7 @@ type (
|
||||
Activity(a *types.Activity) error
|
||||
Message(m *types.Message) error
|
||||
MessageFlag(m *types.MessageFlag) error
|
||||
UnreadCounters(uu types.UnreadSet) error
|
||||
Channel(m *types.Channel) error
|
||||
Join(userID, channelID uint64) error
|
||||
Part(userID, channelID uint64) error
|
||||
@@ -85,6 +86,12 @@ func (svc event) MessageFlag(f *types.MessageFlag) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc event) UnreadCounters(uu types.UnreadSet) error {
|
||||
return uu.Walk(func(u *types.Unread) error {
|
||||
return svc.push(payload.Unread(u), types.EventQueueItemSubTypeUser, u.UserID)
|
||||
})
|
||||
}
|
||||
|
||||
// Channel notifies subscribers about channel change
|
||||
//
|
||||
// If this is a public channel we notify everyone
|
||||
@@ -106,7 +113,7 @@ func (svc event) Join(userID, channelID uint64) (err error) {
|
||||
join := payload.ChannelJoin(channelID, userID)
|
||||
|
||||
// Subscribe user to the channel
|
||||
if err = svc.push(join, types.EventQueueItemSubTypeUser, 0); err != nil {
|
||||
if err = svc.push(join, types.EventQueueItemSubTypeUser, userID); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -131,7 +138,7 @@ func (svc event) Part(userID, channelID uint64) (err error) {
|
||||
}
|
||||
|
||||
// Subscribe user to the channel
|
||||
if err = svc.push(part, types.EventQueueItemSubTypeUser, 0); err != nil {
|
||||
if err = svc.push(part, types.EventQueueItemSubTypeUser, userID); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -189,7 +189,7 @@ func (svc message) filterMessagesByAccessibleChannels(mm types.MessageSet) types
|
||||
return mm
|
||||
}
|
||||
|
||||
func (svc message) Create(in *types.Message) (message *types.Message, err error) {
|
||||
func (svc message) Create(in *types.Message) (m *types.Message, err error) {
|
||||
if in == nil {
|
||||
in = &types.Message{}
|
||||
}
|
||||
@@ -211,7 +211,7 @@ func (svc message) Create(in *types.Message) (message *types.Message, err error)
|
||||
in.UserID = auth.GetIdentityFromContext(svc.ctx).Identity()
|
||||
}
|
||||
|
||||
return message, svc.db.Transaction(func() (err error) {
|
||||
return m, svc.db.Transaction(func() (err error) {
|
||||
// Broadcast queue
|
||||
var bq = types.MessageSet{}
|
||||
var ch *types.Channel
|
||||
@@ -279,19 +279,21 @@ func (svc message) Create(in *types.Message) (message *types.Message, err error)
|
||||
return ErrNoPermissions.withStack()
|
||||
}
|
||||
|
||||
if message, err = svc.message.Create(in); err != nil {
|
||||
if m, err = svc.message.Create(in); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = svc.updateMentions(message.ID, svc.extractMentions(message)); err != nil {
|
||||
mentions := svc.extractMentions(m)
|
||||
if err = svc.updateMentions(m.ID, mentions); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = svc.unread.Inc(message.ChannelID, message.ReplyTo, message.UserID); err != nil {
|
||||
return
|
||||
}
|
||||
svc.sendNotifications(m, mentions)
|
||||
|
||||
return svc.sendEvent(append(bq, message)...)
|
||||
// Count unreads in the background and send updates to all users
|
||||
svc.countUnreads(ch, m, 0)
|
||||
|
||||
return svc.sendEvent(append(bq, m)...)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -413,17 +415,16 @@ func (svc message) Delete(messageID uint64) error {
|
||||
return
|
||||
}
|
||||
|
||||
if err = svc.unread.Dec(deletedMsg.ChannelID, deletedMsg.ReplyTo, deletedMsg.UserID); err != nil {
|
||||
return err
|
||||
} else {
|
||||
// Set deletedAt timestamp so that our clients can react properly...
|
||||
deletedMsg.DeletedAt = timeNowPtr()
|
||||
}
|
||||
// Set deletedAt timestamp so that our clients can react properly...
|
||||
deletedMsg.DeletedAt = timeNowPtr()
|
||||
|
||||
if err = svc.updateMentions(messageID, nil); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Count unreads in the background and send updates to all users
|
||||
svc.countUnreads(ch, deletedMsg, 0)
|
||||
|
||||
return svc.sendEvent(append(bq, deletedMsg)...)
|
||||
})
|
||||
}
|
||||
@@ -435,7 +436,7 @@ func (svc message) MarkAsRead(channelID, threadID, lastReadMessageID uint64) (ui
|
||||
var (
|
||||
currentUserID uint64 = repository.Identity(svc.ctx)
|
||||
count uint32
|
||||
tcount uint32
|
||||
threadCount uint32
|
||||
err error
|
||||
)
|
||||
|
||||
@@ -463,9 +464,13 @@ func (svc message) MarkAsRead(channelID, threadID, lastReadMessageID uint64) (ui
|
||||
// This is request for channel,
|
||||
// count all thread unreads
|
||||
var uu types.UnreadSet
|
||||
uu, err = svc.unread.Find(&types.UnreadFilter{UserID: currentUserID, ChannelID: channelID})
|
||||
uu, err = svc.unread.CountThreads(currentUserID, channelID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if u := uu.FindByChannelId(channelID); u != nil {
|
||||
tcount = u.InThreadCount
|
||||
threadCount = u.ThreadCount
|
||||
}
|
||||
}
|
||||
|
||||
@@ -497,10 +502,17 @@ func (svc message) MarkAsRead(channelID, threadID, lastReadMessageID uint64) (ui
|
||||
}
|
||||
|
||||
err = svc.unread.Record(currentUserID, channelID, threadID, lastReadMessageID, count)
|
||||
return errors.Wrap(err, "unable to record unread messages")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "unable to record unread messages")
|
||||
}
|
||||
|
||||
// Re-count unreads and send updates to this user
|
||||
svc.countUnreads(ch, nil, currentUserID)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return lastReadMessageID, count, tcount, errors.Wrap(err, "unable to mark as read")
|
||||
return lastReadMessageID, count, threadCount, errors.Wrap(err, "unable to mark as read")
|
||||
}
|
||||
|
||||
// React on a message with an emoji
|
||||
@@ -599,9 +611,7 @@ func (svc message) flag(messageID uint64, flag string, remove bool) error {
|
||||
return
|
||||
}
|
||||
|
||||
// @todo: log possible error
|
||||
svc.sendFlagEvent(f)
|
||||
|
||||
_ = svc.sendFlagEvent(f)
|
||||
return
|
||||
})
|
||||
|
||||
@@ -706,7 +716,7 @@ func (svc message) preloadUnreads(mm types.MessageSet) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if vv, err := svc.unread.Find(&types.UnreadFilter{UserID: userID, ThreadIDs: mm.IDs()}); err != nil {
|
||||
if vv, err := svc.unread.Count(userID, 0, mm.IDs()...); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return mm.Walk(func(m *types.Message) error {
|
||||
@@ -731,6 +741,90 @@ func (svc message) sendEvent(mm ...*types.Message) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// Generates and sends notifications from the new message
|
||||
//
|
||||
//
|
||||
func (svc message) sendNotifications(message *types.Message, mentions types.MentionSet) {
|
||||
// @todo implementation
|
||||
}
|
||||
|
||||
// countUnreads orchestrates unread-related operations (inc/dec, (re)counting & sending events)
|
||||
//
|
||||
// 1. increases/decreases unread counters for channel or thread
|
||||
// 2. collects all counters for channel or thread
|
||||
// 3. sends unread events to subscribers
|
||||
func (svc message) countUnreads(ch *types.Channel, m *types.Message, userID uint64) {
|
||||
var (
|
||||
err error
|
||||
uuBase, uuThreads, uuChannels types.UnreadSet
|
||||
// mm types.ChannelMemberSet
|
||||
threadIDs []uint64
|
||||
)
|
||||
|
||||
if m != nil {
|
||||
if m.DeletedAt != nil {
|
||||
// When deleting message, all existing counters are decreased!
|
||||
if err = svc.unread.Dec(m.ChannelID, m.ReplyTo, m.UserID); err != nil {
|
||||
svc.logger.With(zap.Error(err)).Info("could not decrement unread counter")
|
||||
return
|
||||
}
|
||||
} else if m.UpdatedAt == nil {
|
||||
// Reset user's counter and set current message ID as last read.
|
||||
err = svc.unread.Record(
|
||||
m.UserID,
|
||||
m.ChannelID,
|
||||
m.ReplyTo,
|
||||
m.ID,
|
||||
0,
|
||||
)
|
||||
|
||||
// When new message is created, update all existing counters
|
||||
if err = svc.unread.Inc(m.ChannelID, m.ReplyTo, m.UserID); err != nil {
|
||||
svc.logger.With(zap.Error(err)).Info("could not increment unread counter")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if m.ReplyTo > 0 {
|
||||
threadIDs = []uint64{m.ReplyTo}
|
||||
}
|
||||
}
|
||||
|
||||
uuBase, err = svc.unread.Count(userID, ch.ID, threadIDs...)
|
||||
if err != nil {
|
||||
svc.logger.With(zap.Error(err)).Info("could not count unread messages")
|
||||
return
|
||||
}
|
||||
|
||||
if len(threadIDs) > 0 {
|
||||
// If base count was done for a thread,
|
||||
// Do another count for channel
|
||||
uuChannels, err = svc.unread.Count(userID, ch.ID)
|
||||
if err != nil {
|
||||
svc.logger.With(zap.Error(err)).Info("could not count unread messages")
|
||||
return
|
||||
}
|
||||
|
||||
uuBase = uuBase.Merge(uuChannels)
|
||||
|
||||
// Now recount all threads for this channel
|
||||
uuThreads, err = svc.unread.CountThreads(userID, ch.ID)
|
||||
if err != nil {
|
||||
svc.logger.With(zap.Error(err)).Info("could not count unread messages")
|
||||
return
|
||||
}
|
||||
|
||||
uuBase = uuBase.Merge(uuThreads)
|
||||
}
|
||||
|
||||
// This is a reply, make sure we fetch the new stats about unread replies and push them to users
|
||||
err = svc.event.UnreadCounters(uuBase)
|
||||
if err != nil {
|
||||
svc.logger.With(zap.Error(err)).Info("could not send unread count event")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Sends message to event loop
|
||||
func (svc message) sendFlagEvent(ff ...*types.MessageFlag) (err error) {
|
||||
for _, f := range ff {
|
||||
|
||||
@@ -68,7 +68,7 @@ func (ctrl *Message) MarkAsRead(ctx context.Context, r *request.MessageMarkAsRea
|
||||
return outgoing.Unread{
|
||||
LastMessageID: messageID,
|
||||
Count: count,
|
||||
InThreadCount: tcount,
|
||||
ThreadCount: tcount,
|
||||
}, err
|
||||
}
|
||||
|
||||
|
||||
@@ -7,13 +7,38 @@ type (
|
||||
UserID uint64 `db:"rel_user"`
|
||||
LastMessageID uint64 `db:"rel_last_message"`
|
||||
|
||||
Count uint32 `db:"count"`
|
||||
InThreadCount uint32 `db:"-"`
|
||||
}
|
||||
|
||||
UnreadFilter struct {
|
||||
UserID uint64
|
||||
ChannelID uint64
|
||||
ThreadIDs []uint64
|
||||
Count uint32 `db:"count"`
|
||||
ThreadCount uint32 `db:"-"`
|
||||
ThreadTotal uint32 `db:"-"`
|
||||
}
|
||||
)
|
||||
|
||||
func (uu UnreadSet) Merge(in UnreadSet) UnreadSet {
|
||||
var (
|
||||
out = append(UnreadSet{}, uu...)
|
||||
olen = len(out)
|
||||
)
|
||||
|
||||
inSet:
|
||||
for _, i := range in {
|
||||
for o := 0; o < olen; o++ {
|
||||
if out[o].UserID == i.UserID && out[o].ChannelID == i.ChannelID && out[o].ReplyTo == i.ReplyTo {
|
||||
if i.Count > 0 {
|
||||
out[o].Count = i.Count
|
||||
}
|
||||
if i.ThreadCount > 0 {
|
||||
out[o].ThreadCount = i.ThreadCount
|
||||
}
|
||||
if i.ThreadTotal > 0 {
|
||||
out[o].ThreadTotal = i.ThreadTotal
|
||||
}
|
||||
|
||||
continue inSet
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, i)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
44
messaging/types/unread_test.go
Normal file
44
messaging/types/unread_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUnreadSet_Merge(t *testing.T) {
|
||||
type args struct {
|
||||
in UnreadSet
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
uu UnreadSet
|
||||
args args
|
||||
want UnreadSet
|
||||
}{
|
||||
{
|
||||
name: "simple",
|
||||
uu: UnreadSet{&Unread{ChannelID: 1, UserID: 1, Count: 2}},
|
||||
args: args{in: UnreadSet{&Unread{ChannelID: 1, UserID: 1, ThreadCount: 3, ThreadTotal: 4}}},
|
||||
want: UnreadSet{&Unread{ChannelID: 1, UserID: 1, Count: 2, ThreadCount: 3, ThreadTotal: 4}},
|
||||
},
|
||||
{
|
||||
name: "empty base",
|
||||
uu: UnreadSet{},
|
||||
args: args{in: UnreadSet{&Unread{ChannelID: 1, UserID: 1, ThreadCount: 3, ThreadTotal: 4}}},
|
||||
want: UnreadSet{&Unread{ChannelID: 1, UserID: 1, Count: 0, ThreadCount: 3, ThreadTotal: 4}},
|
||||
},
|
||||
{
|
||||
name: "emmpt input",
|
||||
uu: UnreadSet{&Unread{ChannelID: 1, UserID: 1, Count: 2}},
|
||||
args: args{in: nil},
|
||||
want: UnreadSet{&Unread{ChannelID: 1, UserID: 1, Count: 2}},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.uu.Merge(tt.args.in); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Merge() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package websocket
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
sentry "github.com/getsentry/sentry-go"
|
||||
"github.com/titpetric/factory"
|
||||
@@ -57,6 +58,10 @@ func (eq *eventQueue) store(ctx context.Context, qp repository.EventsRepository)
|
||||
}
|
||||
|
||||
func (eq *eventQueue) feedSessions(ctx context.Context, qp repository.EventsRepository, store eventQueueWalker) error {
|
||||
var (
|
||||
userID uint64
|
||||
)
|
||||
|
||||
for {
|
||||
item, err := qp.Pull(ctx)
|
||||
if err != nil {
|
||||
@@ -64,6 +69,11 @@ func (eq *eventQueue) feedSessions(ctx context.Context, qp repository.EventsRepo
|
||||
}
|
||||
|
||||
if item.SubType == types.EventQueueItemSubTypeUser {
|
||||
userID = payload.ParseUInt64(item.Subscriber)
|
||||
if userID == 0 {
|
||||
return errors.New("subscriber could not be parsed as uint64")
|
||||
}
|
||||
|
||||
p := &outgoing.Payload{}
|
||||
|
||||
if err := json.Unmarshal(item.Payload, p); err != nil {
|
||||
@@ -76,7 +86,7 @@ func (eq *eventQueue) feedSessions(ctx context.Context, qp repository.EventsRepo
|
||||
// This store.Walk handler does not send to subscribed sessions but
|
||||
// subscribes all sessions that belong to the same user
|
||||
store.Walk(func(s *Session) {
|
||||
if payload.Uint64toa(s.user.Identity()) == p.ChannelJoin.UserID {
|
||||
if s.user.Identity() == userID {
|
||||
s.subs.Add(p.ChannelJoin.ID)
|
||||
}
|
||||
})
|
||||
@@ -86,25 +96,28 @@ func (eq *eventQueue) feedSessions(ctx context.Context, qp repository.EventsRepo
|
||||
// This store.Walk handler does not send to subscribed sessions but
|
||||
// subscribes all sessions that belong to the same user
|
||||
store.Walk(func(s *Session) {
|
||||
if payload.Uint64toa(s.user.Identity()) == p.ChannelPart.UserID {
|
||||
if s.user.Identity() == userID {
|
||||
s.subs.Delete(p.ChannelPart.ID)
|
||||
}
|
||||
})
|
||||
} else {
|
||||
// No other payload types we can handle at the moment.
|
||||
return nil
|
||||
store.Walk(func(s *Session) {
|
||||
if s.user.Identity() == userID {
|
||||
_ = s.sendBytes(item.Payload)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
} else if item.Subscriber == "" {
|
||||
// Distribute payload to all connected sessions
|
||||
store.Walk(func(s *Session) {
|
||||
s.sendBytes(item.Payload)
|
||||
_ = s.sendBytes(item.Payload)
|
||||
})
|
||||
} else {
|
||||
// Distribute payload to specific subscribers
|
||||
store.Walk(func(s *Session) {
|
||||
if s.subs.Get(item.Subscriber) != nil {
|
||||
s.sendBytes(item.Payload)
|
||||
_ = s.sendBytes(item.Payload)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user