3
0

Unread refactoring, moving logic to backend

This commit is contained in:
Denis Arh
2019-07-19 20:47:02 +02:00
parent 8d7fb9c814
commit ba3b59acd3
12 changed files with 414 additions and 111 deletions

View File

@@ -40,7 +40,7 @@ func Message(ctx context.Context, msg *messagingTypes.Message) *outgoing.Message
ReplyTo: msg.ReplyTo,
Replies: msg.Replies,
RepliesFrom: Uint64stoa(msg.RepliesFrom),
Unread: Unread(msg.Unread),
Unread: MessageUnread(msg.Unread),
Attachment: Attachment(msg.Attachment, currentUserID),
Mentions: messageMentionSet(msg.Mentions),
@@ -145,7 +145,7 @@ func Channel(ch *messagingTypes.Channel) *outgoing.Channel {
Type: string(ch.Type),
MembershipFlag: string(flag),
Members: Uint64stoa(ch.Members),
Unread: Unread(ch.Unread),
Unread: ChannelUnread(ch.Unread),
CanJoin: ch.CanJoin,
CanPart: ch.CanPart,
@@ -196,10 +196,37 @@ func Unread(v *messagingTypes.Unread) *outgoing.Unread {
return nil
}
return &outgoing.Unread{
ChannelID: v.ChannelID,
ThreadID: v.ReplyTo,
LastMessageID: v.LastMessageID,
Count: v.Count,
ThreadCount: v.ThreadCount,
ThreadTotal: v.ThreadTotal,
}
}
func ChannelUnread(v *messagingTypes.Unread) *outgoing.Unread {
if v == nil || (v.Count == 0 && v.ThreadCount == 0) {
return nil
}
return &outgoing.Unread{
LastMessageID: v.LastMessageID,
Count: v.Count,
ThreadCount: v.ThreadCount,
ThreadTotal: v.ThreadTotal,
}
}
func MessageUnread(v *messagingTypes.Unread) *outgoing.Unread {
if v == nil || v.Count == 0 {
return nil
}
return &outgoing.Unread{
LastMessageID: v.LastMessageID,
Count: v.Count,
InThreadCount: v.InThreadCount,
}
}

View File

@@ -19,6 +19,8 @@ type (
*Channel `json:"channel,omitempty"`
*ChannelSet `json:"channels,omitempty"`
*Unread `json:"unread,omitempty"`
*ChannelMember `json:"channelMember,omitempty"`
*ChannelMemberSet `json:"channelMembers,omitempty"`

View File

@@ -1,10 +1,20 @@
package outgoing
import "encoding/json"
type (
Unread struct {
// Channel to part (nil) for ALL channels
ChannelID uint64 `json:"channelID,string,omitempty"`
ThreadID uint64 `json:"threadID,string,omitempty"`
LastMessageID uint64 `json:"lastMessageID,string,omitempty"`
Count uint32 `json:"count"`
InThreadCount uint32 `json:"tcount,omitempty"`
ThreadCount uint32 `json:"threadCount"`
ThreadTotal uint32 `json:"threadTotal,omitempty"`
}
)
func (p *Unread) EncodeMessage() ([]byte, error) {
return json.Marshal(Payload{Unread: p})
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"github.com/titpetric/factory"
"gopkg.in/Masterminds/squirrel.v1"
"github.com/cortezaproject/corteza-server/internal/auth"
)
@@ -45,3 +46,21 @@ func (r *repository) db() *factory.DB {
}
return DB(r.ctx)
}
// Fetches single row from table
func (r repository) fetchSet(set interface{}, q squirrel.SelectBuilder) (err error) {
var (
sql string
args []interface{}
)
if sql, args, err = q.ToSql(); err != nil {
return
}
if err = r.db().Select(set, sql, args...); err != nil {
return
}
return
}

View File

@@ -3,8 +3,8 @@ package repository
import (
"context"
"github.com/jmoiron/sqlx"
"github.com/titpetric/factory"
"gopkg.in/Masterminds/squirrel.v1"
"github.com/cortezaproject/corteza-server/messaging/types"
)
@@ -14,7 +14,9 @@ type (
UnreadRepository interface {
With(ctx context.Context, db *factory.DB) UnreadRepository
Find(filter *types.UnreadFilter) (types.UnreadSet, error)
Count(userID, channelID uint64, threadIDs ...uint64) (types.UnreadSet, error)
CountThreads(userID, channelID uint64) (types.UnreadSet, error)
Preset(channelID, threadID uint64, userIDs ...uint64) (err error)
Record(userID, channelID, threadID, lastReadMessageID uint64, count uint32) error
Inc(channelID, replyTo, userID uint64) error
@@ -30,17 +32,6 @@ type (
)
const (
// Fetching channel members of all channels a specific user has access to
sqlUnreadSelect = `SELECT rel_channel, rel_reply_to, rel_user, count, rel_last_message
FROM messaging_unread
WHERE count > 0 && rel_last_message > 0 `
// Fetching channel members of all channels a specific user has access to
sqlThreadUnreadSelect = `SELECT rel_channel, sum(count) as count
FROM messaging_unread
WHERE rel_user = ? AND rel_reply_to > 0
GROUP BY rel_channel`
sqlUnreadIncCount = `UPDATE messaging_unread
SET count = count + 1
WHERE rel_channel = ? AND rel_reply_to = ? AND rel_user <> ?`
@@ -49,6 +40,8 @@ const (
SET count = count - 1
WHERE rel_channel = ? AND rel_reply_to = ? AND count > 0`
sqlResetCount = `REPLACE INTO messaging_unread (rel_channel, rel_reply_to, rel_user, count) VALUES (?, ?, ?, 0)`
sqlUnreadPresetChannel = `INSERT IGNORE INTO messaging_unread (rel_channel, rel_reply_to, rel_user) VALUES (?, ?, ?)`
sqlUnreadPresetThreads = `INSERT IGNORE INTO messaging_unread (rel_channel, rel_reply_to, rel_user)
SELECT rel_channel, id, ?
@@ -62,6 +55,10 @@ func Unread(ctx context.Context, db *factory.DB) UnreadRepository {
return (&unread{}).With(ctx, db)
}
func (r unread) table() string {
return "messaging_unread"
}
// With context...
func (r *unread) With(ctx context.Context, db *factory.DB) UnreadRepository {
return &unread{
@@ -69,55 +66,84 @@ func (r *unread) With(ctx context.Context, db *factory.DB) UnreadRepository {
}
}
// Find unread info
func (r *unread) Find(filter *types.UnreadFilter) (uu types.UnreadSet, err error) {
params := make([]interface{}, 0)
sql := sqlUnreadSelect
// Count returns counts unread channel info
func (r *unread) Count(userID, channelID uint64, threadIDs ...uint64) (types.UnreadSet, error) {
var (
uu = types.UnreadSet{}
q = squirrel.
Select().
From(r.table()).
Columns(
"rel_channel",
"rel_last_message",
"rel_user",
"rel_reply_to",
"count")
)
if filter != nil {
if filter.UserID > 0 {
// scope: only channel we have access to
sql += ` AND rel_user = ?`
params = append(params, filter.UserID)
}
if filter.ChannelID > 0 {
// scope: only channel we have access to
sql += ` AND rel_channel = ?`
params = append(params, filter.ChannelID)
}
if len(filter.ThreadIDs) > 0 {
sql += ` AND rel_reply_to IN (?)`
params = append(params, filter.ThreadIDs)
} else {
sql += ` AND rel_reply_to = 0`
}
if userID > 0 {
q = q.Where("rel_user = ?", userID)
}
if sql, params, err = sqlx.In(sql, params...); err != nil {
if channelID > 0 {
q = q.Where("rel_channel = ?", channelID)
}
if len(threadIDs) == 0 {
q = q.Where("rel_reply_to = 0")
} else {
q = q.Where(squirrel.Eq{"rel_reply_to": threadIDs})
}
return uu, r.fetchSet(&uu, q)
}
// CountReplies counts unread thread info
func (r unread) CountThreads(userID, channelID uint64) (types.UnreadSet, error) {
type (
u struct {
Rel_channel, Rel_user uint64
Total, Count uint32
}
)
var (
err error
uu = types.UnreadSet{}
temp = []*u{}
q = squirrel.
Select().
From(r.table()).
Columns(
"rel_channel",
"rel_user",
"sum(count) AS count",
"sum(CASE WHEN count > 0 THEN 1 ELSE 0 END) AS total").
Where("rel_reply_to > 0 AND count > 0").
GroupBy("rel_channel", "rel_user")
)
if userID > 0 {
q = q.Where("rel_user = ?", userID)
}
if channelID > 0 {
q = q.Where("rel_channel = ?", channelID)
}
err = r.fetchSet(&temp, q)
if err != nil {
return nil, err
} else if err = r.db().Select(&uu, sql, params...); err != nil {
return nil, err
} else if len(filter.ThreadIDs) == 0 && filter.UserID > 0 {
// Check for unread thread messages
}
// We'll abuse Unread/UnreadSet
tt := types.UnreadSet{}
err = r.db().Select(&tt, sqlThreadUnreadSelect, filter.UserID)
_ = tt.Walk(func(t *types.Unread) error {
c := uu.FindByChannelId(t.ChannelID)
if c != nil {
c.InThreadCount = t.Count
} else {
// No un-reads in channel but we have them in threads (of that channel)
// swap values and append
t.InThreadCount, t.Count = t.Count, 0
uu = append(uu, t)
}
return nil
for _, t := range temp {
uu = append(uu, &types.Unread{
ChannelID: t.Rel_channel,
UserID: t.Rel_user,
ThreadCount: t.Count,
ThreadTotal: t.Total,
})
}
@@ -171,15 +197,27 @@ func (r *unread) Record(userID, channelID, threadID, lastReadMessageID uint64, c
}
// Inc increments unread message count on a channel/thread for all but one user
func (r *unread) Inc(channelID, threadID, userID uint64) error {
_, err := r.db().Exec(sqlUnreadIncCount, channelID, threadID, userID)
return err
func (r *unread) Inc(channelID, threadID, userID uint64) (err error) {
_, err = r.db().Exec(sqlUnreadIncCount, channelID, threadID, userID)
if err != nil {
return err
}
return nil
}
// Dec decrements unread message count on a channel/thread for all but one user
func (r *unread) Dec(channelID, threadID, userID uint64) error {
_, err := r.db().Exec(sqlUnreadDecCount, channelID, threadID)
return err
func (r *unread) Dec(channelID, threadID, userID uint64) (err error) {
_, err = r.db().Exec(sqlUnreadDecCount, channelID, threadID)
if err != nil {
return err
}
_, err = r.db().Exec(sqlResetCount, channelID, threadID, userID)
if err != nil {
return err
}
return nil
}
func (r *unread) CountOwned(userID uint64) (c int, err error) {

View File

@@ -170,17 +170,41 @@ func (svc *channel) preloadMembers(cc types.ChannelSet) (err error) {
return
}
// preload channel unread info for a single user
func (svc *channel) preloadUnreads(cc types.ChannelSet) error {
var userID = auth.GetIdentityFromContext(svc.ctx).Identity()
if vv, err := svc.unread.Find(&types.UnreadFilter{UserID: userID}); err != nil {
if uu, err := svc.unread.Count(userID, 0); err != nil {
return err
} else {
return cc.Walk(func(ch *types.Channel) error {
ch.Unread = vv.FindByChannelId(ch.ID)
_ = cc.Walk(func(ch *types.Channel) error {
ch.Unread = uu.FindByChannelId(ch.ID)
return nil
})
}
if uu, err := svc.unread.CountThreads(userID, 0); err != nil {
return err
} else {
_ = cc.Walk(func(ch *types.Channel) error {
var u = uu.FindByChannelId(ch.ID)
if u == nil {
return nil
}
if ch.Unread == nil {
ch.Unread = &types.Unread{}
}
ch.Unread.ThreadCount = u.ThreadCount
ch.Unread.ThreadTotal = u.ThreadTotal
return nil
})
}
return nil
}
// FindMembers loads all members (and full users) for a specific channel

View File

@@ -24,6 +24,7 @@ type (
Activity(a *types.Activity) error
Message(m *types.Message) error
MessageFlag(m *types.MessageFlag) error
UnreadCounters(uu types.UnreadSet) error
Channel(m *types.Channel) error
Join(userID, channelID uint64) error
Part(userID, channelID uint64) error
@@ -85,6 +86,12 @@ func (svc event) MessageFlag(f *types.MessageFlag) error {
return nil
}
func (svc event) UnreadCounters(uu types.UnreadSet) error {
return uu.Walk(func(u *types.Unread) error {
return svc.push(payload.Unread(u), types.EventQueueItemSubTypeUser, u.UserID)
})
}
// Channel notifies subscribers about channel change
//
// If this is a public channel we notify everyone
@@ -106,7 +113,7 @@ func (svc event) Join(userID, channelID uint64) (err error) {
join := payload.ChannelJoin(channelID, userID)
// Subscribe user to the channel
if err = svc.push(join, types.EventQueueItemSubTypeUser, 0); err != nil {
if err = svc.push(join, types.EventQueueItemSubTypeUser, userID); err != nil {
return
}
@@ -131,7 +138,7 @@ func (svc event) Part(userID, channelID uint64) (err error) {
}
// Subscribe user to the channel
if err = svc.push(part, types.EventQueueItemSubTypeUser, 0); err != nil {
if err = svc.push(part, types.EventQueueItemSubTypeUser, userID); err != nil {
return
}

View File

@@ -189,7 +189,7 @@ func (svc message) filterMessagesByAccessibleChannels(mm types.MessageSet) types
return mm
}
func (svc message) Create(in *types.Message) (message *types.Message, err error) {
func (svc message) Create(in *types.Message) (m *types.Message, err error) {
if in == nil {
in = &types.Message{}
}
@@ -211,7 +211,7 @@ func (svc message) Create(in *types.Message) (message *types.Message, err error)
in.UserID = auth.GetIdentityFromContext(svc.ctx).Identity()
}
return message, svc.db.Transaction(func() (err error) {
return m, svc.db.Transaction(func() (err error) {
// Broadcast queue
var bq = types.MessageSet{}
var ch *types.Channel
@@ -279,19 +279,21 @@ func (svc message) Create(in *types.Message) (message *types.Message, err error)
return ErrNoPermissions.withStack()
}
if message, err = svc.message.Create(in); err != nil {
if m, err = svc.message.Create(in); err != nil {
return
}
if err = svc.updateMentions(message.ID, svc.extractMentions(message)); err != nil {
mentions := svc.extractMentions(m)
if err = svc.updateMentions(m.ID, mentions); err != nil {
return
}
if err = svc.unread.Inc(message.ChannelID, message.ReplyTo, message.UserID); err != nil {
return
}
svc.sendNotifications(m, mentions)
return svc.sendEvent(append(bq, message)...)
// Count unreads in the background and send updates to all users
svc.countUnreads(ch, m, 0)
return svc.sendEvent(append(bq, m)...)
})
}
@@ -413,17 +415,16 @@ func (svc message) Delete(messageID uint64) error {
return
}
if err = svc.unread.Dec(deletedMsg.ChannelID, deletedMsg.ReplyTo, deletedMsg.UserID); err != nil {
return err
} else {
// Set deletedAt timestamp so that our clients can react properly...
deletedMsg.DeletedAt = timeNowPtr()
}
// Set deletedAt timestamp so that our clients can react properly...
deletedMsg.DeletedAt = timeNowPtr()
if err = svc.updateMentions(messageID, nil); err != nil {
return
}
// Count unreads in the background and send updates to all users
svc.countUnreads(ch, deletedMsg, 0)
return svc.sendEvent(append(bq, deletedMsg)...)
})
}
@@ -435,7 +436,7 @@ func (svc message) MarkAsRead(channelID, threadID, lastReadMessageID uint64) (ui
var (
currentUserID uint64 = repository.Identity(svc.ctx)
count uint32
tcount uint32
threadCount uint32
err error
)
@@ -463,9 +464,13 @@ func (svc message) MarkAsRead(channelID, threadID, lastReadMessageID uint64) (ui
// This is request for channel,
// count all thread unreads
var uu types.UnreadSet
uu, err = svc.unread.Find(&types.UnreadFilter{UserID: currentUserID, ChannelID: channelID})
uu, err = svc.unread.CountThreads(currentUserID, channelID)
if err != nil {
return err
}
if u := uu.FindByChannelId(channelID); u != nil {
tcount = u.InThreadCount
threadCount = u.ThreadCount
}
}
@@ -497,10 +502,17 @@ func (svc message) MarkAsRead(channelID, threadID, lastReadMessageID uint64) (ui
}
err = svc.unread.Record(currentUserID, channelID, threadID, lastReadMessageID, count)
return errors.Wrap(err, "unable to record unread messages")
if err != nil {
return errors.Wrap(err, "unable to record unread messages")
}
// Re-count unreads and send updates to this user
svc.countUnreads(ch, nil, currentUserID)
return nil
})
return lastReadMessageID, count, tcount, errors.Wrap(err, "unable to mark as read")
return lastReadMessageID, count, threadCount, errors.Wrap(err, "unable to mark as read")
}
// React on a message with an emoji
@@ -599,9 +611,7 @@ func (svc message) flag(messageID uint64, flag string, remove bool) error {
return
}
// @todo: log possible error
svc.sendFlagEvent(f)
_ = svc.sendFlagEvent(f)
return
})
@@ -706,7 +716,7 @@ func (svc message) preloadUnreads(mm types.MessageSet) error {
return nil
}
if vv, err := svc.unread.Find(&types.UnreadFilter{UserID: userID, ThreadIDs: mm.IDs()}); err != nil {
if vv, err := svc.unread.Count(userID, 0, mm.IDs()...); err != nil {
return err
} else {
return mm.Walk(func(m *types.Message) error {
@@ -731,6 +741,90 @@ func (svc message) sendEvent(mm ...*types.Message) (err error) {
return
}
// Generates and sends notifications from the new message
//
//
func (svc message) sendNotifications(message *types.Message, mentions types.MentionSet) {
// @todo implementation
}
// countUnreads orchestrates unread-related operations (inc/dec, (re)counting & sending events)
//
// 1. increases/decreases unread counters for channel or thread
// 2. collects all counters for channel or thread
// 3. sends unread events to subscribers
func (svc message) countUnreads(ch *types.Channel, m *types.Message, userID uint64) {
var (
err error
uuBase, uuThreads, uuChannels types.UnreadSet
// mm types.ChannelMemberSet
threadIDs []uint64
)
if m != nil {
if m.DeletedAt != nil {
// When deleting message, all existing counters are decreased!
if err = svc.unread.Dec(m.ChannelID, m.ReplyTo, m.UserID); err != nil {
svc.logger.With(zap.Error(err)).Info("could not decrement unread counter")
return
}
} else if m.UpdatedAt == nil {
// Reset user's counter and set current message ID as last read.
err = svc.unread.Record(
m.UserID,
m.ChannelID,
m.ReplyTo,
m.ID,
0,
)
// When new message is created, update all existing counters
if err = svc.unread.Inc(m.ChannelID, m.ReplyTo, m.UserID); err != nil {
svc.logger.With(zap.Error(err)).Info("could not increment unread counter")
return
}
}
if m.ReplyTo > 0 {
threadIDs = []uint64{m.ReplyTo}
}
}
uuBase, err = svc.unread.Count(userID, ch.ID, threadIDs...)
if err != nil {
svc.logger.With(zap.Error(err)).Info("could not count unread messages")
return
}
if len(threadIDs) > 0 {
// If base count was done for a thread,
// Do another count for channel
uuChannels, err = svc.unread.Count(userID, ch.ID)
if err != nil {
svc.logger.With(zap.Error(err)).Info("could not count unread messages")
return
}
uuBase = uuBase.Merge(uuChannels)
// Now recount all threads for this channel
uuThreads, err = svc.unread.CountThreads(userID, ch.ID)
if err != nil {
svc.logger.With(zap.Error(err)).Info("could not count unread messages")
return
}
uuBase = uuBase.Merge(uuThreads)
}
// This is a reply, make sure we fetch the new stats about unread replies and push them to users
err = svc.event.UnreadCounters(uuBase)
if err != nil {
svc.logger.With(zap.Error(err)).Info("could not send unread count event")
return
}
}
// Sends message to event loop
func (svc message) sendFlagEvent(ff ...*types.MessageFlag) (err error) {
for _, f := range ff {

View File

@@ -68,7 +68,7 @@ func (ctrl *Message) MarkAsRead(ctx context.Context, r *request.MessageMarkAsRea
return outgoing.Unread{
LastMessageID: messageID,
Count: count,
InThreadCount: tcount,
ThreadCount: tcount,
}, err
}

View File

@@ -7,13 +7,38 @@ type (
UserID uint64 `db:"rel_user"`
LastMessageID uint64 `db:"rel_last_message"`
Count uint32 `db:"count"`
InThreadCount uint32 `db:"-"`
}
UnreadFilter struct {
UserID uint64
ChannelID uint64
ThreadIDs []uint64
Count uint32 `db:"count"`
ThreadCount uint32 `db:"-"`
ThreadTotal uint32 `db:"-"`
}
)
func (uu UnreadSet) Merge(in UnreadSet) UnreadSet {
var (
out = append(UnreadSet{}, uu...)
olen = len(out)
)
inSet:
for _, i := range in {
for o := 0; o < olen; o++ {
if out[o].UserID == i.UserID && out[o].ChannelID == i.ChannelID && out[o].ReplyTo == i.ReplyTo {
if i.Count > 0 {
out[o].Count = i.Count
}
if i.ThreadCount > 0 {
out[o].ThreadCount = i.ThreadCount
}
if i.ThreadTotal > 0 {
out[o].ThreadTotal = i.ThreadTotal
}
continue inSet
}
}
out = append(out, i)
}
return out
}

View File

@@ -0,0 +1,44 @@
package types
import (
"reflect"
"testing"
)
func TestUnreadSet_Merge(t *testing.T) {
type args struct {
in UnreadSet
}
tests := []struct {
name string
uu UnreadSet
args args
want UnreadSet
}{
{
name: "simple",
uu: UnreadSet{&Unread{ChannelID: 1, UserID: 1, Count: 2}},
args: args{in: UnreadSet{&Unread{ChannelID: 1, UserID: 1, ThreadCount: 3, ThreadTotal: 4}}},
want: UnreadSet{&Unread{ChannelID: 1, UserID: 1, Count: 2, ThreadCount: 3, ThreadTotal: 4}},
},
{
name: "empty base",
uu: UnreadSet{},
args: args{in: UnreadSet{&Unread{ChannelID: 1, UserID: 1, ThreadCount: 3, ThreadTotal: 4}}},
want: UnreadSet{&Unread{ChannelID: 1, UserID: 1, Count: 0, ThreadCount: 3, ThreadTotal: 4}},
},
{
name: "emmpt input",
uu: UnreadSet{&Unread{ChannelID: 1, UserID: 1, Count: 2}},
args: args{in: nil},
want: UnreadSet{&Unread{ChannelID: 1, UserID: 1, Count: 2}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.uu.Merge(tt.args.in); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Merge() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -3,6 +3,7 @@ package websocket
import (
"context"
"encoding/json"
"errors"
sentry "github.com/getsentry/sentry-go"
"github.com/titpetric/factory"
@@ -57,6 +58,10 @@ func (eq *eventQueue) store(ctx context.Context, qp repository.EventsRepository)
}
func (eq *eventQueue) feedSessions(ctx context.Context, qp repository.EventsRepository, store eventQueueWalker) error {
var (
userID uint64
)
for {
item, err := qp.Pull(ctx)
if err != nil {
@@ -64,6 +69,11 @@ func (eq *eventQueue) feedSessions(ctx context.Context, qp repository.EventsRepo
}
if item.SubType == types.EventQueueItemSubTypeUser {
userID = payload.ParseUInt64(item.Subscriber)
if userID == 0 {
return errors.New("subscriber could not be parsed as uint64")
}
p := &outgoing.Payload{}
if err := json.Unmarshal(item.Payload, p); err != nil {
@@ -76,7 +86,7 @@ func (eq *eventQueue) feedSessions(ctx context.Context, qp repository.EventsRepo
// This store.Walk handler does not send to subscribed sessions but
// subscribes all sessions that belong to the same user
store.Walk(func(s *Session) {
if payload.Uint64toa(s.user.Identity()) == p.ChannelJoin.UserID {
if s.user.Identity() == userID {
s.subs.Add(p.ChannelJoin.ID)
}
})
@@ -86,25 +96,28 @@ func (eq *eventQueue) feedSessions(ctx context.Context, qp repository.EventsRepo
// This store.Walk handler does not send to subscribed sessions but
// subscribes all sessions that belong to the same user
store.Walk(func(s *Session) {
if payload.Uint64toa(s.user.Identity()) == p.ChannelPart.UserID {
if s.user.Identity() == userID {
s.subs.Delete(p.ChannelPart.ID)
}
})
} else {
// No other payload types we can handle at the moment.
return nil
store.Walk(func(s *Session) {
if s.user.Identity() == userID {
_ = s.sendBytes(item.Payload)
}
})
}
} else if item.Subscriber == "" {
// Distribute payload to all connected sessions
store.Walk(func(s *Session) {
s.sendBytes(item.Payload)
_ = s.sendBytes(item.Payload)
})
} else {
// Distribute payload to specific subscribers
store.Walk(func(s *Session) {
if s.subs.Get(item.Subscriber) != nil {
s.sendBytes(item.Payload)
_ = s.sendBytes(item.Payload)
}
})
}