From 926005d2cc1fe10ce227689391fc6cd00f790ead Mon Sep 17 00:00:00 2001 From: Denis Arh Date: Sat, 28 Jul 2018 17:49:41 +0200 Subject: [PATCH] Rework channel-related commants, sending messages --- sam/websocket/incoming.go | 4 +-- sam/websocket/incoming/channel.go | 8 ++--- sam/websocket/incoming/payload.go | 6 ++-- sam/websocket/incoming_channel.go | 51 +++++++++++++++++++++++------- sam/websocket/incoming_message.go | 8 +++-- sam/websocket/outgoing.go | 33 ++++++++++++------- sam/websocket/outgoing/messages.go | 29 +++++++++++++++++ sam/websocket/outgoing/payload.go | 13 ++++++++ sam/websocket/payload.go | 16 ++++++++++ sam/websocket/subscriptions.go | 16 +++++----- 10 files changed, 141 insertions(+), 43 deletions(-) diff --git a/sam/websocket/incoming.go b/sam/websocket/incoming.go index f52ebc848..67dbddcfa 100644 --- a/sam/websocket/incoming.go +++ b/sam/websocket/incoming.go @@ -28,8 +28,8 @@ func (s *Session) dispatch(raw []byte) (err error) { return s.channelJoin(ctx, *p.ChannelJoin) case p.ChannelPart != nil: return s.channelPart(ctx, *p.ChannelPart) - case p.ChannelPart != nil: - return s.channelPartAll(ctx, *p.ChannelPartAll) + case p.ChannelList != nil: + return s.channelList(ctx, *p.ChannelList) case p.ChannelOpen != nil: return s.channelOpen(ctx, *p.ChannelOpen) diff --git a/sam/websocket/incoming/channel.go b/sam/websocket/incoming/channel.go index 54307412e..8e5b08973 100644 --- a/sam/websocket/incoming/channel.go +++ b/sam/websocket/incoming/channel.go @@ -1,16 +1,14 @@ package incoming type ( + ChannelList struct{} + ChannelJoin struct { ChannelID string `json:"cid"` } ChannelPart struct { - ChannelID string `json:"cid"` - } - - ChannelPartAll struct { - Leave bool `json:"leave"` + ChannelID *string `json:"cid"` } ChannelOpen struct { diff --git a/sam/websocket/incoming/payload.go b/sam/websocket/incoming/payload.go index c0f71d4dc..0caf8674e 100644 --- a/sam/websocket/incoming/payload.go +++ b/sam/websocket/incoming/payload.go @@ -2,9 +2,9 @@ package incoming type Payload struct { // Channel actions - *ChannelJoin `json:"chjoin"` - *ChannelPart `json:"chpart"` - *ChannelPartAll `json:"chpartall"` + *ChannelList `json:"chlist"` + *ChannelJoin `json:"chjoin"` + *ChannelPart `json:"chpart"` // Get channel message history *ChannelOpen `json:"chopen"` diff --git a/sam/websocket/incoming_channel.go b/sam/websocket/incoming_channel.go index 734fa9083..aecb7becd 100644 --- a/sam/websocket/incoming_channel.go +++ b/sam/websocket/incoming_channel.go @@ -2,30 +2,58 @@ package websocket import ( "context" + "github.com/crusttech/crust/auth" "github.com/crusttech/crust/sam/service" "github.com/crusttech/crust/sam/types" "github.com/crusttech/crust/sam/websocket/incoming" + "github.com/crusttech/crust/sam/websocket/outgoing" ) func (s *Session) channelJoin(ctx context.Context, p incoming.ChannelJoin) error { - var () - // @todo: check access to channel - s.subs.Add(p.ChannelID, &Subscription{}) + // @todo: check access / can we join this channel (should be done by service layer) + + s.subs.Add(p.ChannelID) + + // Telling all subscribers of the channel we're joining that we are joining. + var chJoin = &outgoing.ChannelJoin{ + ID: p.ChannelID, + UserID: uint64toa(auth.GetIdentityFromContext(ctx).GetID()), + } + + // Send to all channel subscribers + s.broadcast(chJoin, &p.ChannelID) + return nil } func (s *Session) channelPart(ctx context.Context, p incoming.ChannelPart) error { - var () - // @todo: check access to channel - s.subs.Delete(p.ChannelID) + // @todo: check access / can we part this channel? (should be done by service layer) + + // First, let's unsubscribe, so we don't hear echos + if p.ChannelID != nil { + s.subs.Delete(*p.ChannelID) + } else { + s.subs.DeleteAll() + } + + // This payload will tell everyone that we're parting from ALL channels + var chPart = &outgoing.ChannelPart{ + ID: p.ChannelID, + UserID: uint64toa(auth.GetIdentityFromContext(ctx).GetID()), + } + + s.broadcast(chPart, p.ChannelID) + return nil } -func (s *Session) channelPartAll(ctx context.Context, p incoming.ChannelPartAll) error { - if p.Leave { - s.subs.DeleteAll() +func (s *Session) channelList(ctx context.Context, p incoming.ChannelList) error { + channels, err := service.Channel().Find(ctx, nil) + if err != nil { + return err } - return nil + + return s.respond(payloadFromChannels(channels)) } func (s *Session) channelOpen(ctx context.Context, p incoming.ChannelOpen) error { @@ -40,5 +68,6 @@ func (s *Session) channelOpen(ctx context.Context, p incoming.ChannelOpen) error if err != nil { return err } - return s.sendMessage(payloadFromMessages(messages)) + + return s.respond(payloadFromMessages(messages)) } diff --git a/sam/websocket/incoming_message.go b/sam/websocket/incoming_message.go index 0ac5488eb..45b67abc2 100644 --- a/sam/websocket/incoming_message.go +++ b/sam/websocket/incoming_message.go @@ -20,7 +20,8 @@ func (s *Session) messageCreate(ctx context.Context, p incoming.MessageCreate) e if err != nil { return err } - return s.sendMessageChannel(uint64toa(msg.ChannelID), payloadFromMessage(msg)) + + return s.broadcast(payloadFromMessage(msg), &p.ChannelID) } func (s *Session) messageUpdate(ctx context.Context, p incoming.MessageUpdate) error { @@ -34,7 +35,8 @@ func (s *Session) messageUpdate(ctx context.Context, p incoming.MessageUpdate) e if err != nil { return err } - return s.sendMessageChannel(uint64toa(msg.ChannelID), &outgoing.MessageUpdate{ID: p.ID, Message: msg.Message}) + + return s.broadcast(&outgoing.MessageUpdate{ID: p.ID, Message: msg.Message}, &p.ID) } func (s *Session) messageDelete(ctx context.Context, p incoming.MessageDelete) error { @@ -46,5 +48,5 @@ func (s *Session) messageDelete(ctx context.Context, p incoming.MessageDelete) e return err } - return s.sendMessageChannel(p.ChannelID, &outgoing.MessageDelete{ID: p.ID}) + return s.broadcast(&outgoing.MessageDelete{ID: p.ID}, &p.ChannelID) } diff --git a/sam/websocket/outgoing.go b/sam/websocket/outgoing.go index 837c6f293..6f49efbf6 100644 --- a/sam/websocket/outgoing.go +++ b/sam/websocket/outgoing.go @@ -1,38 +1,49 @@ package websocket import ( - "log" "encoding/json" - "time" "github.com/crusttech/crust/sam/websocket/outgoing" + "log" + "time" ) -func (s *Session) sendMessageChannel(channelID string, message outgoing.PayloadType) error { +// Sends message to all connected subscribers/users +// +// If channelID is nil, it broadcasts payload to anyone connected, +// if it is set, it broadcasts only to clients subscribed to that channel +func (s *Session) broadcast(message outgoing.PayloadType, channelID *string) error { p := outgoing.Payload{}.New().Load(message) // encode message once and send bytes pb, err := json.Marshal(p) if err != nil { return err } + store.Walk(func(sess *Session) { // send message only to users with subscribed channels - if sess.subs.Get(channelID) != nil { - select { - case sess.send <- pb: - case <-time.After(2 * time.Millisecond): - log.Println("websocket.messageChannel send timeout") - } + if channelID == nil || sess.subs.Get(*channelID) != nil { + sess.sendBytes(pb) } }) + return nil } -func (s *Session) sendMessage(message outgoing.PayloadType) error { +func (s *Session) respond(message outgoing.PayloadType) error { p := outgoing.Payload{}.New().Load(message) select { case s.send <- p: case <-time.After(2 * time.Millisecond): - log.Println("websocket.messageChannel send timeout") + log.Println("websocket.respond send timeout") + } + return nil +} + +func (s *Session) sendBytes(p []byte) error { + select { + case s.send <- p: + case <-time.After(2 * time.Millisecond): + log.Println("websocket.sendBytes send timeout") } return nil } diff --git a/sam/websocket/outgoing/messages.go b/sam/websocket/outgoing/messages.go index 9b35af869..8d9fd5797 100644 --- a/sam/websocket/outgoing/messages.go +++ b/sam/websocket/outgoing/messages.go @@ -20,6 +20,30 @@ type ( MessageDelete struct { ID string `json:"id"` } + + ChannelJoin struct { + // ID of the channel user is joining + ID string `json:"id"` + + // ID of the user that is joining + UserID string `json:"uid"` + } + + ChannelPart struct { + // Channel to part (nil) for ALL channels + ID *string `json:"id"` + + // Who is parting + UserID string `json:"uid"` + } + + Channel struct { + // Channel to part (nil) for ALL channels + ID string `json:"id"` + Name string `json:"name"` + } + + Channels []*Channel ) func (*Message) valid() bool { return true } @@ -27,3 +51,8 @@ func (*Messages) valid() bool { return true } func (*MessageUpdate) valid() bool { return true } func (*MessageDelete) valid() bool { return true } + +func (*ChannelJoin) valid() bool { return true } +func (*ChannelPart) valid() bool { return true } +func (*Channel) valid() bool { return true } +func (*Channels) valid() bool { return true } diff --git a/sam/websocket/outgoing/payload.go b/sam/websocket/outgoing/payload.go index 6419beab7..8c9858f52 100644 --- a/sam/websocket/outgoing/payload.go +++ b/sam/websocket/outgoing/payload.go @@ -13,6 +13,11 @@ type ( *MessageUpdate `json:"mu,omitempty"` *Messages `json:"ms,omitempty"` + *ChannelJoin `json:"chj,omitempty"` + *ChannelPart `json:"chp,omitempty"` + *Channel `json:"ch,omitempty"` + *Channels `json:"chs,omitempty"` + // @todo: implement outgoing message types timestamp time.Time } @@ -33,6 +38,14 @@ func (p *Payload) Load(payload PayloadType) *Payload { p.MessageDelete = val case *MessageUpdate: p.MessageUpdate = val + case *ChannelJoin: + p.ChannelJoin = val + case *ChannelPart: + p.ChannelPart = val + case *Channel: + p.Channel = val + case *Channels: + p.Channels = val default: panic(fmt.Sprintf("Unknown/unsupported Payload type: %T", val)) } diff --git a/sam/websocket/payload.go b/sam/websocket/payload.go index 2d52d3473..0670c73e6 100644 --- a/sam/websocket/payload.go +++ b/sam/websocket/payload.go @@ -25,3 +25,19 @@ func payloadFromMessages(msg []*types.Message) *outgoing.Messages { retval := outgoing.Messages(msgs) return &retval } + +func payloadFromChannel(ch *types.Channel) *outgoing.Channel { + return &outgoing.Channel{ + ID: strconv.FormatUint(ch.ID, 10), + Name: ch.Name, + } +} + +func payloadFromChannels(channels []*types.Channel) *outgoing.Channels { + cc := make([]*outgoing.Channel, len(channels)) + for k, c := range channels { + cc[k] = payloadFromChannel(c) + } + retval := outgoing.Channels(cc) + return &retval +} diff --git a/sam/websocket/subscriptions.go b/sam/websocket/subscriptions.go index c28fc339c..199d9f5fa 100644 --- a/sam/websocket/subscriptions.go +++ b/sam/websocket/subscriptions.go @@ -25,29 +25,29 @@ func (Subscriptions) New() *Subscriptions { // @todo: load/save all subscriptions from database -func (s *Subscriptions) Add(name string, sub *Subscription) string { +func (s *Subscriptions) Add(channelID string) *Subscription { s.Lock() defer s.Unlock() - s.Subscriptions[name] = sub - return name + s.Subscriptions[channelID] = &Subscription{} + return s.Subscriptions[channelID] } -func (s *Subscriptions) Get(name string) *Subscription { +func (s *Subscriptions) Get(channelID string) *Subscription { s.RLock() defer s.RUnlock() - return s.Subscriptions[name] + return s.Subscriptions[channelID] } -func (s *Subscriptions) Delete(name string) { +func (s *Subscriptions) Delete(channelID string) { s.Lock() defer s.Unlock() - delete(s.Subscriptions, name) + delete(s.Subscriptions, channelID) } func (s *Subscriptions) DeleteAll() { s.Lock() defer s.Unlock() - for index, _ := range s.Subscriptions { + for index := range s.Subscriptions { delete(s.Subscriptions, index) } }